Fix the large majority of clippy warnings
This commit fixes most of the remaining clippy warnings in the codebase. These warnings were the more semantically difficult ones to fix. There are some warnings that remain from the rebase that will be fixed in the upcoming PR.
This commit is contained in:
parent
b6b80b9ffe
commit
d3da8745c8
@ -84,13 +84,6 @@ impl StringValue {
|
||||
}
|
||||
}
|
||||
|
||||
impl StringValue {
|
||||
/// Clones the internal string value.
|
||||
pub fn to_string(&self) -> String {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for StringValue {
|
||||
type Target = str;
|
||||
|
||||
|
1
clippy.toml
Normal file
1
clippy.toml
Normal file
@ -0,0 +1 @@
|
||||
enum-variant-size-threshold = 1000
|
@ -10,19 +10,16 @@ use tokio::{
|
||||
#[error(no_from)]
|
||||
pub enum Error {
|
||||
#[error(display = "Failed to open the address cache file")]
|
||||
OpenAddressCache(#[error(source)] io::Error),
|
||||
Open(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Failed to read the address cache file")]
|
||||
ReadAddressCache(#[error(source)] io::Error),
|
||||
Read(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Failed to parse the address cache file")]
|
||||
ParseAddressCache,
|
||||
Parse,
|
||||
|
||||
#[error(display = "Failed to update the address cache file")]
|
||||
WriteAddressCache(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "The address cache is empty")]
|
||||
EmptyAddressCache,
|
||||
Write(#[error(source)] io::Error),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -68,7 +65,7 @@ impl AddressCache {
|
||||
self.inner.lock().await.address
|
||||
}
|
||||
|
||||
pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> {
|
||||
pub async fn set_address(&self, address: SocketAddr) -> Result<(), Error> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
if address != inner.address {
|
||||
self.save_to_disk(&address).await?;
|
||||
@ -77,17 +74,21 @@ impl AddressCache {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> {
|
||||
async fn save_to_disk(&self, address: &SocketAddr) -> Result<(), Error> {
|
||||
let write_path = match self.write_path.as_ref() {
|
||||
Some(write_path) => write_path,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf()).await?;
|
||||
let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf())
|
||||
.await
|
||||
.map_err(Error::Open)?;
|
||||
let mut contents = address.to_string();
|
||||
contents += "\n";
|
||||
file.write_all(contents.as_bytes()).await?;
|
||||
file.finalize().await
|
||||
file.write_all(contents.as_bytes())
|
||||
.await
|
||||
.map_err(Error::Write)?;
|
||||
file.finalize().await.map_err(Error::Write)
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,12 +104,10 @@ impl AddressCacheInner {
|
||||
}
|
||||
|
||||
async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> {
|
||||
let mut file = fs::File::open(path)
|
||||
.await
|
||||
.map_err(Error::OpenAddressCache)?;
|
||||
let mut file = fs::File::open(path).await.map_err(Error::Open)?;
|
||||
let mut address = String::new();
|
||||
file.read_to_string(&mut address)
|
||||
.await
|
||||
.map_err(Error::ReadAddressCache)?;
|
||||
address.trim().parse().map_err(|_| Error::ParseAddressCache)
|
||||
.map_err(Error::Read)?;
|
||||
address.trim().parse().map_err(|_| Error::Parse)
|
||||
}
|
||||
|
@ -304,7 +304,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
|
||||
)
|
||||
.await?;
|
||||
let tls_stream = TlsStream::connect_https(socket, &hostname).await?;
|
||||
Ok::<_, io::Error>(ApiConnection::Direct(tls_stream))
|
||||
Ok::<_, io::Error>(ApiConnection::Direct(Box::new(tls_stream)))
|
||||
}
|
||||
InnerConnectionMode::Proxied(proxy_config) => {
|
||||
let socket = Self::open_socket(
|
||||
@ -320,7 +320,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
|
||||
addr,
|
||||
);
|
||||
let tls_stream = TlsStream::connect_https(proxy, &hostname).await?;
|
||||
Ok(ApiConnection::Proxied(tls_stream))
|
||||
Ok(ApiConnection::Proxied(Box::new(tls_stream)))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -236,7 +236,7 @@ impl Runtime {
|
||||
new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static,
|
||||
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
|
||||
) -> rest::RequestServiceHandle {
|
||||
let service_handle = rest::RequestService::new(
|
||||
let service_handle = rest::RequestService::spawn(
|
||||
sni_hostname,
|
||||
self.api_availability.handle(),
|
||||
self.address_cache.clone(),
|
||||
|
@ -132,8 +132,8 @@ impl ApiConnectionMode {
|
||||
|
||||
/// Stream that is either a regular TLS stream or TLS via shadowsocks
|
||||
pub enum ApiConnection {
|
||||
Direct(TlsStream<TcpStream>),
|
||||
Proxied(TlsStream<ProxyClientStream<TcpStream>>),
|
||||
Direct(Box<TlsStream<TcpStream>>),
|
||||
Proxied(Box<TlsStream<ProxyClientStream<TcpStream>>>),
|
||||
}
|
||||
|
||||
impl AsyncRead for ApiConnection {
|
||||
|
@ -130,7 +130,7 @@ impl ServerRelayList {
|
||||
) {
|
||||
let openvpn_endpoint_data = openvpn.ports;
|
||||
for mut openvpn_relay in openvpn.relays.into_iter() {
|
||||
openvpn_relay.to_lower();
|
||||
openvpn_relay.convert_to_lowercase();
|
||||
if let Some((country_code, city_code)) = split_location_code(&openvpn_relay.location) {
|
||||
if let Some(country) = countries.get_mut(country_code) {
|
||||
if let Some(city) = country
|
||||
@ -184,7 +184,7 @@ impl ServerRelayList {
|
||||
};
|
||||
|
||||
for mut wireguard_relay in relays {
|
||||
wireguard_relay.relay.to_lower();
|
||||
wireguard_relay.relay.convert_to_lowercase();
|
||||
if let Some((country_code, city_code)) =
|
||||
split_location_code(&wireguard_relay.relay.location)
|
||||
{
|
||||
@ -235,7 +235,7 @@ impl ServerRelayList {
|
||||
} = bridges;
|
||||
|
||||
for mut bridge_relay in relays {
|
||||
bridge_relay.to_lower();
|
||||
bridge_relay.convert_to_lowercase();
|
||||
if let Some((country_code, city_code)) = split_location_code(&bridge_relay.location) {
|
||||
if let Some(country) = countries.get_mut(country_code) {
|
||||
if let Some(city) = country
|
||||
@ -345,7 +345,7 @@ struct Relay {
|
||||
}
|
||||
|
||||
impl Relay {
|
||||
fn to_lower(&mut self) {
|
||||
fn convert_to_lowercase(&mut self) {
|
||||
self.hostname = self.hostname.to_lowercase();
|
||||
self.location = self.location.to_lowercase();
|
||||
}
|
||||
|
@ -130,7 +130,7 @@ impl<
|
||||
> RequestService<T, F>
|
||||
{
|
||||
/// Constructs a new request service.
|
||||
pub async fn new(
|
||||
pub async fn spawn(
|
||||
sni_hostname: Option<String>,
|
||||
api_availability: ApiAvailabilityHandle,
|
||||
address_cache: AddressCache,
|
||||
|
@ -7,26 +7,26 @@ use tokio::{fs, io};
|
||||
#[error(no_from)]
|
||||
pub enum Error {
|
||||
#[error(display = "Failed to get path")]
|
||||
PathError(#[error(source)] mullvad_paths::Error),
|
||||
Path(#[error(source)] mullvad_paths::Error),
|
||||
|
||||
#[error(display = "Failed to remove directory {}", _0)]
|
||||
RemoveDirError(String, #[error(source)] io::Error),
|
||||
RemoveDir(String, #[error(source)] io::Error),
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[error(display = "Failed to create directory {}", _0)]
|
||||
CreateDirError(String, #[error(source)] io::Error),
|
||||
CreateDir(String, #[error(source)] io::Error),
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
#[error(display = "Failed to get file type info")]
|
||||
FileTypeError(#[error(source)] io::Error),
|
||||
FileType(#[error(source)] io::Error),
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
#[error(display = "Failed to get dir entry")]
|
||||
FileEntryError(#[error(source)] io::Error),
|
||||
FileEntry(#[error(source)] io::Error),
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
#[error(display = "Failed to read dir entries")]
|
||||
ReadDirError(#[error(source)] io::Error),
|
||||
ReadDir(#[error(source)] io::Error),
|
||||
}
|
||||
|
||||
pub async fn clear_directories() -> Result<(), Error> {
|
||||
@ -35,12 +35,12 @@ pub async fn clear_directories() -> Result<(), Error> {
|
||||
}
|
||||
|
||||
async fn clear_log_directory() -> Result<(), Error> {
|
||||
let log_dir = mullvad_paths::get_log_dir().map_err(Error::PathError)?;
|
||||
let log_dir = mullvad_paths::get_log_dir().map_err(Error::Path)?;
|
||||
clear_directory(&log_dir).await
|
||||
}
|
||||
|
||||
async fn clear_cache_directory() -> Result<(), Error> {
|
||||
let cache_dir = mullvad_paths::cache_dir().map_err(Error::PathError)?;
|
||||
let cache_dir = mullvad_paths::cache_dir().map_err(Error::Path)?;
|
||||
clear_directory(&cache_dir).await
|
||||
}
|
||||
|
||||
@ -49,22 +49,22 @@ async fn clear_directory(path: &Path) -> Result<(), Error> {
|
||||
{
|
||||
fs::remove_dir_all(path)
|
||||
.await
|
||||
.map_err(|e| Error::RemoveDirError(path.display().to_string(), e))?;
|
||||
.map_err(|e| Error::RemoveDir(path.display().to_string(), e))?;
|
||||
fs::create_dir_all(path)
|
||||
.await
|
||||
.map_err(|e| Error::CreateDirError(path.display().to_string(), e))
|
||||
.map_err(|e| Error::CreateDir(path.display().to_string(), e))
|
||||
}
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDirError)?;
|
||||
let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDir)?;
|
||||
|
||||
let mut result = Ok(());
|
||||
|
||||
while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntryError)? {
|
||||
while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntry)? {
|
||||
let entry_type = match entry.file_type().await {
|
||||
Ok(entry_type) => entry_type,
|
||||
Err(error) => {
|
||||
result = result.and(Err(Error::FileTypeError(error)));
|
||||
result = result.and(Err(Error::FileType(error)));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@ -74,9 +74,8 @@ async fn clear_directory(path: &Path) -> Result<(), Error> {
|
||||
} else {
|
||||
fs::remove_dir_all(entry.path()).await
|
||||
};
|
||||
result = result.and(
|
||||
removal.map_err(|e| Error::RemoveDirError(entry.path().display().to_string(), e)),
|
||||
);
|
||||
result = result
|
||||
.and(removal.map_err(|e| Error::RemoveDir(entry.path().display().to_string(), e)));
|
||||
}
|
||||
result
|
||||
}
|
||||
|
@ -35,10 +35,10 @@ impl CurrentApiCall {
|
||||
}
|
||||
|
||||
pub fn is_validating(&self) -> bool {
|
||||
match &self.current_call {
|
||||
Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_)) => true,
|
||||
_ => false,
|
||||
}
|
||||
matches!(
|
||||
&self.current_call,
|
||||
Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_))
|
||||
)
|
||||
}
|
||||
|
||||
pub fn is_running_timed_totation(&self) -> bool {
|
||||
@ -51,10 +51,7 @@ impl CurrentApiCall {
|
||||
|
||||
pub fn is_logging_in(&self) -> bool {
|
||||
use Call::*;
|
||||
match &self.current_call {
|
||||
Some(Login(..)) => true,
|
||||
_ => false,
|
||||
}
|
||||
matches!(&self.current_call, Some(Login(..)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,7 +5,7 @@ use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal
|
||||
|
||||
use std::{convert::TryFrom, sync::Once};
|
||||
|
||||
const INIT_ONCE: Once = Once::new();
|
||||
static INIT_ONCE: Once = Once::new();
|
||||
|
||||
const FAULT_SIGNALS: [Signal; 5] = [
|
||||
// Access to invalid memory address
|
||||
|
@ -31,7 +31,7 @@ use crate::target_state::PersistentTargetState;
|
||||
use device::{PrivateAccountAndDevice, PrivateDeviceEvent};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
future::{abortable, AbortHandle, Future},
|
||||
future::{abortable, AbortHandle, Future, LocalBoxFuture},
|
||||
StreamExt,
|
||||
};
|
||||
use mullvad_relay_selector::{
|
||||
@ -385,6 +385,12 @@ pub struct DaemonCommandChannel {
|
||||
receiver: mpsc::UnboundedReceiver<InternalDaemonEvent>,
|
||||
}
|
||||
|
||||
impl Default for DaemonCommandChannel {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl DaemonCommandChannel {
|
||||
pub fn new() -> Self {
|
||||
let (untracked_sender, receiver) = mpsc::unbounded();
|
||||
@ -472,13 +478,13 @@ impl<E> Sender<E> for DaemonEventSender<E>
|
||||
where
|
||||
InternalDaemonEvent: From<E>,
|
||||
{
|
||||
fn send(&self, event: E) -> Result<(), ()> {
|
||||
fn send(&self, event: E) -> Result<(), talpid_core::mpsc::Error> {
|
||||
if let Some(sender) = self.sender.upgrade() {
|
||||
sender
|
||||
.unbounded_send(InternalDaemonEvent::from(event))
|
||||
.map_err(|_| ())
|
||||
.map_err(|_| talpid_core::mpsc::Error::ChannelClosed)
|
||||
} else {
|
||||
Err(())
|
||||
Err(talpid_core::mpsc::Error::ChannelClosed)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -684,7 +690,7 @@ where
|
||||
relay_list_listener.notify_relay_list(relay_list.clone());
|
||||
};
|
||||
|
||||
let mut relay_list_updater = RelayListUpdater::new(
|
||||
let mut relay_list_updater = RelayListUpdater::spawn(
|
||||
relay_selector.clone(),
|
||||
api_handle.clone(),
|
||||
&cache_dir,
|
||||
@ -785,11 +791,11 @@ where
|
||||
|
||||
/// Shuts down the daemon without shutting down the underlying event listener and the shutdown
|
||||
/// callbacks
|
||||
fn shutdown(
|
||||
fn shutdown<'a>(
|
||||
self,
|
||||
) -> (
|
||||
L,
|
||||
Vec<Pin<Box<dyn Future<Output = ()>>>>,
|
||||
Vec<LocalBoxFuture<'a, ()>>,
|
||||
mullvad_api::Runtime,
|
||||
TunnelStateMachineHandle,
|
||||
) {
|
||||
@ -845,11 +851,11 @@ where
|
||||
TunnelStateTransition::Disconnected => TunnelState::Disconnected,
|
||||
TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting {
|
||||
endpoint,
|
||||
location: self.parameters_generator.get_last_location(),
|
||||
location: self.parameters_generator.get_last_location().await,
|
||||
},
|
||||
TunnelStateTransition::Connected(endpoint) => TunnelState::Connected {
|
||||
endpoint,
|
||||
location: self.parameters_generator.get_last_location(),
|
||||
location: self.parameters_generator.get_last_location().await,
|
||||
},
|
||||
TunnelStateTransition::Disconnecting(after_disconnect) => {
|
||||
TunnelState::Disconnecting(after_disconnect)
|
||||
@ -1184,7 +1190,7 @@ where
|
||||
}
|
||||
Disconnecting(..) => Self::oneshot_send(
|
||||
tx,
|
||||
self.parameters_generator.get_last_location(),
|
||||
self.parameters_generator.get_last_location().await,
|
||||
"current location",
|
||||
),
|
||||
Connected { location, .. } => {
|
||||
@ -1703,7 +1709,8 @@ where
|
||||
Self::oneshot_send(tx, Ok(()), "use_wireguard_nt response");
|
||||
if settings_changed {
|
||||
self.parameters_generator
|
||||
.set_tunnel_options(&self.settings.tunnel_options);
|
||||
.set_tunnel_options(&self.settings.tunnel_options)
|
||||
.await;
|
||||
self.event_listener
|
||||
.notify_settings(self.settings.to_settings());
|
||||
if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
|
||||
@ -1854,7 +1861,8 @@ where
|
||||
Self::oneshot_send(tx, Ok(()), "set_openvpn_mssfix response");
|
||||
if settings_changed {
|
||||
self.parameters_generator
|
||||
.set_tunnel_options(&self.settings.tunnel_options);
|
||||
.set_tunnel_options(&self.settings.tunnel_options)
|
||||
.await;
|
||||
self.event_listener
|
||||
.notify_settings(self.settings.to_settings());
|
||||
if self.get_target_tunnel_type() == Some(TunnelType::OpenVpn) {
|
||||
@ -1963,7 +1971,8 @@ where
|
||||
Self::oneshot_send(tx, Ok(()), "set_enable_ipv6 response");
|
||||
if settings_changed {
|
||||
self.parameters_generator
|
||||
.set_tunnel_options(&self.settings.tunnel_options);
|
||||
.set_tunnel_options(&self.settings.tunnel_options)
|
||||
.await;
|
||||
self.event_listener
|
||||
.notify_settings(self.settings.to_settings());
|
||||
log::info!("Initiating tunnel restart because the enable IPv6 setting changed");
|
||||
@ -1991,7 +2000,8 @@ where
|
||||
Self::oneshot_send(tx, Ok(()), "set_quantum_resistant_tunnel response");
|
||||
if settings_changed {
|
||||
self.parameters_generator
|
||||
.set_tunnel_options(&self.settings.tunnel_options);
|
||||
.set_tunnel_options(&self.settings.tunnel_options)
|
||||
.await;
|
||||
self.event_listener
|
||||
.notify_settings(self.settings.to_settings());
|
||||
if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) {
|
||||
@ -2021,7 +2031,8 @@ where
|
||||
let resolvers =
|
||||
dns::addresses_from_options(&settings.tunnel_options.dns_options);
|
||||
self.parameters_generator
|
||||
.set_tunnel_options(&settings.tunnel_options);
|
||||
.set_tunnel_options(&settings.tunnel_options)
|
||||
.await;
|
||||
self.event_listener.notify_settings(settings);
|
||||
self.send_tunnel_command(TunnelCommand::Dns(resolvers));
|
||||
}
|
||||
@ -2044,7 +2055,8 @@ where
|
||||
Self::oneshot_send(tx, Ok(()), "set_wireguard_mtu response");
|
||||
if settings_changed {
|
||||
self.parameters_generator
|
||||
.set_tunnel_options(&self.settings.tunnel_options);
|
||||
.set_tunnel_options(&self.settings.tunnel_options)
|
||||
.await;
|
||||
self.event_listener
|
||||
.notify_settings(self.settings.to_settings());
|
||||
if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
|
||||
@ -2086,7 +2098,8 @@ where
|
||||
);
|
||||
}
|
||||
self.parameters_generator
|
||||
.set_tunnel_options(&self.settings.tunnel_options);
|
||||
.set_tunnel_options(&self.settings.tunnel_options)
|
||||
.await;
|
||||
self.event_listener
|
||||
.notify_settings(self.settings.to_settings());
|
||||
}
|
||||
|
@ -151,7 +151,7 @@ impl ManagementService for ManagementServiceImpl {
|
||||
self.send_command_to_daemon(DaemonCommand::GetVersionInfo(tx))?;
|
||||
self.wait_for_result(rx)
|
||||
.await?
|
||||
.ok_or(Status::not_found("no version cache"))
|
||||
.ok_or_else(|| Status::not_found("no version cache"))
|
||||
.map(types::AppVersionInfo::from)
|
||||
.map(Response::new)
|
||||
}
|
||||
|
@ -53,12 +53,12 @@ pub async fn migrate_formats(settings_dir: &Path, settings: &mut serde_json::Val
|
||||
.read(true)
|
||||
.open(path)
|
||||
.await
|
||||
.map_err(Error::ReadHistoryError)?;
|
||||
.map_err(Error::ReadHistory)?;
|
||||
|
||||
let mut bytes = vec![];
|
||||
file.read_to_end(&mut bytes)
|
||||
.await
|
||||
.map_err(Error::ReadHistoryError)?;
|
||||
.map_err(Error::ReadHistory)?;
|
||||
|
||||
if is_format_v3(&bytes) {
|
||||
return Ok(());
|
||||
@ -92,16 +92,16 @@ fn is_format_v3(bytes: &[u8]) -> bool {
|
||||
}
|
||||
|
||||
async fn write_format_v3(mut file: File, token: Option<AccountToken>) -> Result<()> {
|
||||
file.set_len(0).await.map_err(Error::WriteHistoryError)?;
|
||||
file.set_len(0).await.map_err(Error::WriteHistory)?;
|
||||
file.seek(io::SeekFrom::Start(0))
|
||||
.await
|
||||
.map_err(Error::WriteHistoryError)?;
|
||||
.map_err(Error::WriteHistory)?;
|
||||
if let Some(token) = token {
|
||||
file.write_all(token.as_bytes())
|
||||
.await
|
||||
.map_err(Error::WriteHistoryError)?;
|
||||
.map_err(Error::WriteHistory)?;
|
||||
}
|
||||
file.sync_all().await.map_err(Error::WriteHistoryError)
|
||||
file.sync_all().await.map_err(Error::WriteHistory)
|
||||
}
|
||||
|
||||
fn try_format_v2(bytes: &[u8]) -> Result<Option<(AccountToken, serde_json::Value)>> {
|
||||
|
@ -57,31 +57,31 @@ const SETTINGS_FILE: &str = "settings.json";
|
||||
#[error(no_from)]
|
||||
pub enum Error {
|
||||
#[error(display = "Failed to read the settings")]
|
||||
ReadError(#[error(source)] io::Error),
|
||||
Read(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Malformed settings")]
|
||||
ParseError(#[error(source)] serde_json::Error),
|
||||
Parse(#[error(source)] serde_json::Error),
|
||||
|
||||
#[error(display = "Unable to read any version of the settings")]
|
||||
NoMatchingVersion,
|
||||
|
||||
#[error(display = "Unable to serialize settings to JSON")]
|
||||
SerializeError(#[error(source)] serde_json::Error),
|
||||
Serialize(#[error(source)] serde_json::Error),
|
||||
|
||||
#[error(display = "Unable to open settings for writing")]
|
||||
OpenError(#[error(source)] io::Error),
|
||||
Open(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Unable to write new settings")]
|
||||
WriteError(#[error(source)] io::Error),
|
||||
Write(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Unable to sync settings to disk")]
|
||||
SyncError(#[error(source)] io::Error),
|
||||
SyncSettings(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Failed to read the account history")]
|
||||
ReadHistoryError(#[error(source)] io::Error),
|
||||
ReadHistory(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Failed to write new account history")]
|
||||
WriteHistoryError(#[error(source)] io::Error),
|
||||
WriteHistory(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Failed to parse account history")]
|
||||
ParseHistoryError,
|
||||
@ -129,10 +129,10 @@ pub(crate) async fn migrate_all(
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let settings_bytes = fs::read(&path).await.map_err(Error::ReadError)?;
|
||||
let settings_bytes = fs::read(&path).await.map_err(Error::Read)?;
|
||||
|
||||
let mut settings: serde_json::Value =
|
||||
serde_json::from_reader(&settings_bytes[..]).map_err(Error::ParseError)?;
|
||||
serde_json::from_reader(&settings_bytes[..]).map_err(Error::Parse)?;
|
||||
|
||||
if !settings.is_object() {
|
||||
return Err(Error::NoMatchingVersion);
|
||||
@ -155,7 +155,7 @@ pub(crate) async fn migrate_all(
|
||||
return Ok(migration_data);
|
||||
}
|
||||
|
||||
let buffer = serde_json::to_string_pretty(&settings).map_err(Error::SerializeError)?;
|
||||
let buffer = serde_json::to_string_pretty(&settings).map_err(Error::Serialize)?;
|
||||
|
||||
let mut options = fs::OpenOptions::new();
|
||||
#[cfg(unix)]
|
||||
@ -168,11 +168,11 @@ pub(crate) async fn migrate_all(
|
||||
.truncate(true)
|
||||
.open(&path)
|
||||
.await
|
||||
.map_err(Error::OpenError)?;
|
||||
.map_err(Error::Open)?;
|
||||
file.write_all(&buffer.into_bytes())
|
||||
.await
|
||||
.map_err(Error::WriteError)?;
|
||||
file.sync_data().await.map_err(Error::SyncError)?;
|
||||
.map_err(Error::Write)?;
|
||||
file.sync_data().await.map_err(Error::SyncSettings)?;
|
||||
|
||||
log::debug!("Migrated settings. Wrote settings to {}", path.display());
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
#![allow(clippy::identity_op)]
|
||||
use super::{Error, Result};
|
||||
use mullvad_types::settings::SettingsVersion;
|
||||
use std::time::Duration;
|
||||
|
@ -66,7 +66,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
|
||||
DnsState::Default
|
||||
};
|
||||
let addresses = if let Some(addrs) = options.get("addresses") {
|
||||
serde_json::from_value(addrs.clone()).map_err(Error::ParseError)?
|
||||
serde_json::from_value(addrs.clone()).map_err(Error::Parse)?
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
@ -43,8 +43,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
|
||||
if let Some(constraints) = wireguard_constraints {
|
||||
let (port, protocol): (Constraint<u16>, TransportProtocol) =
|
||||
if let Some(port) = constraints.get("port") {
|
||||
let port_constraint =
|
||||
serde_json::from_value(port.clone()).map_err(Error::ParseError)?;
|
||||
let port_constraint = serde_json::from_value(port.clone()).map_err(Error::Parse)?;
|
||||
match port_constraint {
|
||||
Constraint::Any => (Constraint::Any, TransportProtocol::Udp),
|
||||
Constraint::Only(port) => (Constraint::Only(port), wg_protocol_from_port(port)),
|
||||
@ -77,13 +76,13 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
|
||||
|
||||
if let Some(constraints) = openvpn_constraints {
|
||||
let port: Constraint<u16> = if let Some(port) = constraints.get("port") {
|
||||
serde_json::from_value(port.clone()).map_err(Error::ParseError)?
|
||||
serde_json::from_value(port.clone()).map_err(Error::Parse)?
|
||||
} else {
|
||||
Constraint::Any
|
||||
};
|
||||
let transport_constraint: Constraint<TransportProtocol> =
|
||||
if let Some(protocol) = constraints.get("protocol") {
|
||||
serde_json::from_value(protocol.clone()).map_err(Error::ParseError)?
|
||||
serde_json::from_value(protocol.clone()).map_err(Error::Parse)?
|
||||
} else {
|
||||
Constraint::Any
|
||||
};
|
||||
|
@ -95,7 +95,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M
|
||||
//
|
||||
if let Some(port) = wireguard_constraints.get("port") {
|
||||
let port_constraint: Constraint<TransportPort> =
|
||||
serde_json::from_value(port.clone()).map_err(Error::ParseError)?;
|
||||
serde_json::from_value(port.clone()).map_err(Error::Parse)?;
|
||||
if let Some(transport_port) = port_constraint.option() {
|
||||
let (port, obfuscation_settings) = match transport_port.protocol {
|
||||
TransportProtocol::Udp => (serde_json::json!(transport_port.port), None),
|
||||
@ -116,8 +116,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M
|
||||
|
||||
let migration_data = if let Some(token) = settings.get("account_token").filter(|t| !t.is_null())
|
||||
{
|
||||
let token: AccountToken =
|
||||
serde_json::from_value(token.clone()).map_err(Error::ParseError)?;
|
||||
let token: AccountToken = serde_json::from_value(token.clone()).map_err(Error::Parse)?;
|
||||
let migration_data =
|
||||
if let Some(wg_data) = settings.get("wireguard").filter(|wg| !wg.is_null()) {
|
||||
Some(MigrationData {
|
||||
|
@ -1,8 +1,6 @@
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use mullvad_relay_selector::{RelaySelector, SelectedBridge, SelectedObfuscator, SelectedRelay};
|
||||
use mullvad_types::{
|
||||
@ -32,7 +30,7 @@ pub enum Error {
|
||||
NoBridgeAvailable,
|
||||
|
||||
#[error(display = "Failed to resolve hostname for custom relay")]
|
||||
ResolveCustomHostnameError,
|
||||
ResolveCustomHostname,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -65,13 +63,13 @@ impl ParametersGenerator {
|
||||
}
|
||||
|
||||
/// Sets the tunnel options to use when generating new tunnel parameters.
|
||||
pub fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) {
|
||||
self.0.lock().unwrap().tunnel_options = tunnel_options.clone();
|
||||
pub async fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) {
|
||||
self.0.lock().await.tunnel_options = tunnel_options.clone();
|
||||
}
|
||||
|
||||
/// Gets the location associated with the last generated tunnel parameters.
|
||||
pub fn get_last_location(&self) -> Option<GeoIpLocation> {
|
||||
let inner = self.0.lock().unwrap();
|
||||
pub async fn get_last_location(&self) -> Option<GeoIpLocation> {
|
||||
let inner = self.0.lock().await;
|
||||
|
||||
let relays = inner.last_generated_relays.as_ref()?;
|
||||
|
||||
@ -131,7 +129,7 @@ impl InnerParametersGenerator {
|
||||
.to_tunnel_parameters(self.tunnel_options.clone(), None)
|
||||
.map_err(|e| {
|
||||
log::error!("Failed to resolve hostname for custom tunnel config: {}", e);
|
||||
Error::ResolveCustomHostnameError
|
||||
Error::ResolveCustomHostname
|
||||
})
|
||||
}
|
||||
Ok((SelectedRelay::Normal(constraints), bridge, obfuscator)) => {
|
||||
@ -246,13 +244,13 @@ impl TunnelParametersGenerator for ParametersGenerator {
|
||||
) -> Pin<Box<dyn Future<Output = Result<TunnelParameters, ParameterGenerationError>>>> {
|
||||
let generator = self.0.clone();
|
||||
Box::pin(async move {
|
||||
let mut inner = generator.lock().unwrap();
|
||||
let mut inner = generator.lock().await;
|
||||
inner
|
||||
.generate(retry_attempt)
|
||||
.await
|
||||
.map_err(|error| match error {
|
||||
Error::NoBridgeAvailable => ParameterGenerationError::NoMatchingBridgeRelay,
|
||||
Error::ResolveCustomHostnameError => {
|
||||
Error::ResolveCustomHostname => {
|
||||
ParameterGenerationError::CustomTunnelHostResultionError
|
||||
}
|
||||
error => {
|
||||
|
@ -14,6 +14,7 @@ use std::{
|
||||
future::Future,
|
||||
io,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
time::Duration,
|
||||
};
|
||||
use talpid_core::mpsc::Sender;
|
||||
@ -297,10 +298,10 @@ impl VersionUpdater {
|
||||
if !*IS_DEV_BUILD {
|
||||
let stable_version = latest_stable
|
||||
.as_ref()
|
||||
.and_then(|stable| ParsedAppVersion::from_str(stable));
|
||||
.and_then(|stable| ParsedAppVersion::from_str(stable).ok());
|
||||
|
||||
let beta_version = if show_beta {
|
||||
ParsedAppVersion::from_str(latest_beta)
|
||||
ParsedAppVersion::from_str(latest_beta).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -339,10 +340,10 @@ impl VersionUpdater {
|
||||
let mut check_delay = next_delay();
|
||||
let mut version_check = futures::future::Fuse::terminated();
|
||||
|
||||
// If this is a dev build ,there's no need to pester the API for version checks.
|
||||
// If this is a dev build, there's no need to pester the API for version checks.
|
||||
if *IS_DEV_BUILD {
|
||||
log::warn!("Not checking for updates because this is a development build");
|
||||
while let Some(_) = rx.next().await {}
|
||||
while rx.next().await.is_some() {}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -384,7 +384,7 @@ impl RelaySelector {
|
||||
let mut relay_matcher = RelayMatcher {
|
||||
location: location.clone(),
|
||||
providers: providers.clone(),
|
||||
ownership: ownership.clone(),
|
||||
ownership: *ownership,
|
||||
tunnel: openvpn_constraints,
|
||||
};
|
||||
|
||||
@ -492,7 +492,7 @@ impl RelaySelector {
|
||||
let mut entry_relay_matcher = RelayMatcher {
|
||||
location: location.clone(),
|
||||
providers: providers.clone(),
|
||||
ownership: ownership.clone(),
|
||||
ownership: *ownership,
|
||||
tunnel: wireguard_constraints.clone().into(),
|
||||
};
|
||||
|
||||
@ -532,7 +532,7 @@ impl RelaySelector {
|
||||
.clone(),
|
||||
..matcher.clone()
|
||||
}
|
||||
.to_wireguard_matcher();
|
||||
.into_wireguard_matcher();
|
||||
|
||||
// Pick the entry relay first if its location constraint is a subset of the exit location.
|
||||
if relay_constraints.wireguard_constraints.use_multihop {
|
||||
@ -746,7 +746,7 @@ impl RelaySelector {
|
||||
let bridge_constraints = InternalBridgeConstraints {
|
||||
location: settings.location.clone(),
|
||||
providers: settings.providers.clone(),
|
||||
ownership: settings.ownership.clone(),
|
||||
ownership: settings.ownership,
|
||||
// FIXME: This is temporary while talpid-core only supports TCP proxies
|
||||
transport_protocol: Constraint::Only(TransportProtocol::Tcp),
|
||||
};
|
||||
@ -791,7 +791,7 @@ impl RelaySelector {
|
||||
BridgeSettings::Normal(settings) => InternalBridgeConstraints {
|
||||
location: settings.location.clone(),
|
||||
providers: settings.providers.clone(),
|
||||
ownership: settings.ownership.clone(),
|
||||
ownership: settings.ownership,
|
||||
transport_protocol: Constraint::Only(TransportProtocol::Tcp),
|
||||
},
|
||||
BridgeSettings::Custom(_bridge_settings) => InternalBridgeConstraints {
|
||||
@ -1064,7 +1064,7 @@ impl RelaySelector {
|
||||
let addr_in = endpoint
|
||||
.as_ref()
|
||||
.map(|endpoint| endpoint.to_endpoint().address.ip())
|
||||
.unwrap_or(IpAddr::from(selected_relay.ipv4_addr_in));
|
||||
.unwrap_or_else(|| IpAddr::from(selected_relay.ipv4_addr_in));
|
||||
log::info!("Selected relay {} at {}", selected_relay.hostname, addr_in);
|
||||
endpoint.map(|endpoint| NormalSelectedRelay::new(endpoint, selected_relay.clone()))
|
||||
})
|
||||
|
@ -34,7 +34,7 @@ impl From<RelayConstraints> for RelayMatcher<AnyTunnelMatcher> {
|
||||
}
|
||||
|
||||
impl RelayMatcher<AnyTunnelMatcher> {
|
||||
pub fn to_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> {
|
||||
pub fn into_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> {
|
||||
RelayMatcher {
|
||||
tunnel: self.tunnel.wireguard,
|
||||
location: self.location,
|
||||
|
@ -57,7 +57,7 @@ pub struct RelayListUpdater {
|
||||
}
|
||||
|
||||
impl RelayListUpdater {
|
||||
pub fn new(
|
||||
pub fn spawn(
|
||||
selector: super::RelaySelector,
|
||||
api_handle: MullvadRestHandle,
|
||||
cache_dir: &Path,
|
||||
|
@ -2,7 +2,7 @@ use clap::{crate_authors, crate_description, crate_name, App};
|
||||
use mullvad_api::{self, proxy::ApiConnectionMode};
|
||||
use mullvad_management_interface::new_rpc_client;
|
||||
use mullvad_types::version::ParsedAppVersion;
|
||||
use std::{path::PathBuf, process, time::Duration};
|
||||
use std::{path::PathBuf, process, str::FromStr, time::Duration};
|
||||
use talpid_core::{
|
||||
firewall::{self, Firewall},
|
||||
future_retry::{constant_interval, retry_future_n},
|
||||
@ -133,7 +133,7 @@ async fn main() {
|
||||
|
||||
async fn is_older_version(old_version: &str) -> Result<ExitStatus, Error> {
|
||||
let parsed_version =
|
||||
ParsedAppVersion::from_str(old_version).ok_or(Error::ParseVersionStringError)?;
|
||||
ParsedAppVersion::from_str(old_version).map_err(|_| Error::ParseVersionStringError)?;
|
||||
|
||||
Ok(if parsed_version < *APP_VERSION {
|
||||
ExitStatus::Ok
|
||||
@ -152,7 +152,7 @@ async fn prepare_restart() -> Result<(), Error> {
|
||||
|
||||
async fn reset_firewall() -> Result<(), Error> {
|
||||
// Ensure that the daemon isn't running
|
||||
if let Ok(_) = new_rpc_client().await {
|
||||
if new_rpc_client().await.is_ok() {
|
||||
return Err(Error::DaemonIsRunning);
|
||||
}
|
||||
|
||||
|
@ -625,19 +625,15 @@ impl RelaySettingsUpdate {
|
||||
RelaySettingsUpdate::CustomTunnelEndpoint(endpoint) => {
|
||||
endpoint.endpoint().protocol == TransportProtocol::Tcp
|
||||
}
|
||||
RelaySettingsUpdate::Normal(update) => {
|
||||
if let Some(constraints) = &update.openvpn_constraints {
|
||||
!matches!(
|
||||
&constraints.port,
|
||||
Constraint::Only(TransportPort {
|
||||
RelaySettingsUpdate::Normal(update) => !matches!(
|
||||
&update.openvpn_constraints,
|
||||
Some(OpenVpnConstraints {
|
||||
port: Constraint::Only(TransportPort {
|
||||
protocol: TransportProtocol::Udp,
|
||||
..
|
||||
})
|
||||
)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
})
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,10 @@
|
||||
use jnix::IntoJava;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::{Ord, Ordering, PartialOrd};
|
||||
use std::{
|
||||
cmp::{Ord, Ordering, PartialOrd},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref STABLE_REGEX: Regex = Regex::new(r"^(\d{4})\.(\d+)$").unwrap();
|
||||
@ -44,30 +47,33 @@ pub enum ParsedAppVersion {
|
||||
Dev(u32, u32, Option<u32>, String),
|
||||
}
|
||||
|
||||
impl ParsedAppVersion {
|
||||
pub fn from_str(version: &str) -> Option<Self> {
|
||||
impl FromStr for ParsedAppVersion {
|
||||
type Err = ();
|
||||
fn from_str(version: &str) -> Result<Self, Self::Err> {
|
||||
let get_int = |cap: ®ex::Captures<'_>, idx| cap.get(idx)?.as_str().parse().ok();
|
||||
|
||||
if let Some(caps) = STABLE_REGEX.captures(version) {
|
||||
let year = get_int(&caps, 1)?;
|
||||
let version = get_int(&caps, 2)?;
|
||||
Some(Self::Stable(year, version))
|
||||
let year = get_int(&caps, 1).ok_or(())?;
|
||||
let version = get_int(&caps, 2).ok_or(())?;
|
||||
Ok(Self::Stable(year, version))
|
||||
} else if let Some(caps) = BETA_REGEX.captures(version) {
|
||||
let year = get_int(&caps, 1)?;
|
||||
let version = get_int(&caps, 2)?;
|
||||
let beta_version = get_int(&caps, 3)?;
|
||||
Some(Self::Beta(year, version, beta_version))
|
||||
let year = get_int(&caps, 1).ok_or(())?;
|
||||
let version = get_int(&caps, 2).ok_or(())?;
|
||||
let beta_version = get_int(&caps, 3).ok_or(())?;
|
||||
Ok(Self::Beta(year, version, beta_version))
|
||||
} else if let Some(caps) = DEV_REGEX.captures(version) {
|
||||
let year = get_int(&caps, 1)?;
|
||||
let version = get_int(&caps, 2)?;
|
||||
let year = get_int(&caps, 1).ok_or(())?;
|
||||
let version = get_int(&caps, 2).ok_or(())?;
|
||||
let beta_version = caps.get(4).map(|_| get_int(&caps, 5).unwrap());
|
||||
let dev_hash = caps.get(6)?.as_str().to_string();
|
||||
Some(Self::Dev(year, version, beta_version, dev_hash))
|
||||
let dev_hash = caps.get(6).ok_or(())?.as_str().to_string();
|
||||
Ok(Self::Dev(year, version, beta_version, dev_hash))
|
||||
} else {
|
||||
None
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParsedAppVersion {
|
||||
pub fn is_dev(&self) -> bool {
|
||||
matches!(self, ParsedAppVersion::Dev(..))
|
||||
}
|
||||
@ -191,7 +197,7 @@ mod test {
|
||||
];
|
||||
|
||||
for (input, expected_output) in tests {
|
||||
assert_eq!(ParsedAppVersion::from_str(input), expected_output,);
|
||||
assert_eq!(ParsedAppVersion::from_str(input).ok(), expected_output,);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,16 +22,16 @@ pub enum Error {
|
||||
RunResolvconf(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Using 'resolvconf' to add a record failed: {}", stderr)]
|
||||
AddRecordError { stderr: String },
|
||||
AddRecord { stderr: String },
|
||||
|
||||
#[error(display = "Using 'resolvconf' to delete a record failed")]
|
||||
DeleteRecordError,
|
||||
DeleteRecord,
|
||||
|
||||
#[error(display = "Detected dnsmasq is runing and misconfigured")]
|
||||
DnsmasqMisconfigurationError,
|
||||
DnsmasqMisconfiguration,
|
||||
|
||||
#[error(display = "Current /etc/resolv.conf is not generated by resolvconf")]
|
||||
ResolvconfNotInUseError,
|
||||
ResolvconfNotInUse,
|
||||
}
|
||||
|
||||
pub struct Resolvconf {
|
||||
@ -50,15 +50,15 @@ impl Resolvconf {
|
||||
|
||||
// Check if resolvconf is managing DNS by /etc/resolv.conf
|
||||
if !is_dnsmasq_running
|
||||
&& !(Self::check_if_resolvconf_is_symlinked_correctly()
|
||||
|| Self::check_if_resolvconf_was_generated())
|
||||
&& !Self::check_if_resolvconf_is_symlinked_correctly()
|
||||
&& !Self::check_if_resolvconf_was_generated()
|
||||
{
|
||||
return Err(Error::ResolvconfNotInUseError);
|
||||
return Err(Error::ResolvconfNotInUse);
|
||||
}
|
||||
|
||||
// Check if resolvconf can manage DNS via dnsmasq
|
||||
if is_dnsmasq_running && Self::is_dnsmasq_configured_wrong() {
|
||||
return Err(Error::DnsmasqMisconfigurationError);
|
||||
return Err(Error::DnsmasqMisconfiguration);
|
||||
}
|
||||
|
||||
Ok(Resolvconf {
|
||||
@ -94,7 +94,7 @@ impl Resolvconf {
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
return Err(Error::AddRecordError { stderr });
|
||||
return Err(Error::AddRecord { stderr });
|
||||
}
|
||||
|
||||
self.record_names.insert(record_name);
|
||||
@ -118,7 +118,7 @@ impl Resolvconf {
|
||||
record_name,
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
result = Err(Error::DeleteRecordError);
|
||||
result = Err(Error::DeleteRecord);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ pub enum Error {
|
||||
ReadResolvConf(&'static str, #[error(source)] io::Error),
|
||||
|
||||
#[error(display = "resolv.conf at {} could not be parsed", _0)]
|
||||
ParseError(&'static str, #[error(source)] resolv_conf::ParseError),
|
||||
Parse(&'static str, #[error(source)] resolv_conf::ParseError),
|
||||
|
||||
#[error(display = "Failed to remove stale resolv.conf backup at {}", _0)]
|
||||
RemoveBackup(&'static str, #[error(source)] io::Error),
|
||||
@ -179,7 +179,7 @@ fn read_config() -> Result<Config> {
|
||||
|
||||
let contents = fs::read_to_string(RESOLV_CONF_PATH)
|
||||
.map_err(|e| Error::ReadResolvConf(RESOLV_CONF_PATH, e))?;
|
||||
let config = Config::parse(&contents).map_err(|e| Error::ParseError(RESOLV_CONF_PATH, e))?;
|
||||
let config = Config::parse(&contents).map_err(|e| Error::Parse(RESOLV_CONF_PATH, e))?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
@ -198,8 +198,8 @@ fn restore_from_backup() -> Result<()> {
|
||||
match fs::read_to_string(RESOLV_CONF_BACKUP_PATH) {
|
||||
Ok(backup) => {
|
||||
log::info!("Restoring DNS state from backup");
|
||||
let config = Config::parse(&backup)
|
||||
.map_err(|e| Error::ParseError(RESOLV_CONF_BACKUP_PATH, e))?;
|
||||
let config =
|
||||
Config::parse(&backup).map_err(|e| Error::Parse(RESOLV_CONF_BACKUP_PATH, e))?;
|
||||
|
||||
write_config(&config)?;
|
||||
|
||||
|
@ -1,11 +1,20 @@
|
||||
/// Error type for `Sender` trait.
|
||||
#[derive(err_derive::Error, Debug)]
|
||||
pub enum Error {
|
||||
/// The underlying channel is closed.
|
||||
#[error(display = "Channel is closed")]
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
/// Abstraction over any type that can be used similarly to an `std::mpsc::Sender`.
|
||||
pub trait Sender<T> {
|
||||
/// Sends an item over the underlying channel, failing only if the channel is closed.
|
||||
fn send(&self, item: T) -> Result<(), ()>;
|
||||
fn send(&self, item: T) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> {
|
||||
fn send(&self, content: E) -> Result<(), ()> {
|
||||
self.unbounded_send(content).map_err(|_| ())
|
||||
fn send(&self, content: E) -> Result<(), Error> {
|
||||
self.unbounded_send(content)
|
||||
.map_err(|_| Error::ChannelClosed)
|
||||
}
|
||||
}
|
||||
|
@ -183,7 +183,7 @@ fn construct_icmpv4_packet_inner(
|
||||
|
||||
let checksum = internet_checksum::checksum(buffer);
|
||||
(&mut buffer[ICMP_CHECKSUM_OFFSET..])
|
||||
.write(&checksum)
|
||||
.write_all(&checksum)
|
||||
.unwrap();
|
||||
|
||||
true
|
||||
|
@ -87,13 +87,13 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
#[error(no_from)]
|
||||
pub enum Error {
|
||||
#[error(display = "Failed to open a netlink connection")]
|
||||
ConnectError(#[error(source)] io::Error),
|
||||
Connect(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Failed to bind netlink socket")]
|
||||
BindError(#[error(source)] io::Error),
|
||||
Bind(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Netlink error")]
|
||||
NetlinkError(#[error(source)] rtnetlink::Error),
|
||||
Netlink(#[error(source)] rtnetlink::Error),
|
||||
|
||||
#[error(display = "Route without a valid node")]
|
||||
InvalidRoute,
|
||||
@ -108,16 +108,16 @@ pub enum Error {
|
||||
UnknownDeviceIndex(u32),
|
||||
|
||||
#[error(display = "Failed to get a route for the given IP address")]
|
||||
GetRouteError(#[error(source)] rtnetlink::Error),
|
||||
GetRoute(#[error(source)] rtnetlink::Error),
|
||||
|
||||
#[error(display = "No netlink response for route query")]
|
||||
NoRouteError,
|
||||
NoRoute,
|
||||
|
||||
#[error(display = "Route node was malformed")]
|
||||
InvalidRouteNode,
|
||||
|
||||
#[error(display = "No link found")]
|
||||
LinkNotFoundError,
|
||||
LinkNotFound,
|
||||
|
||||
/// Unable to create routing table for tagged connections and packets.
|
||||
#[error(display = "Cannot find a free routing table ID")]
|
||||
@ -140,14 +140,11 @@ pub struct RouteManagerImpl {
|
||||
impl RouteManagerImpl {
|
||||
pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
|
||||
let (mut connection, handle, messages) =
|
||||
rtnetlink::new_connection().map_err(Error::ConnectError)?;
|
||||
rtnetlink::new_connection().map_err(Error::Connect)?;
|
||||
|
||||
let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_LINK | RTMGRP_NOTIFY;
|
||||
let addr = SocketAddr::new(0, mgroup_flags);
|
||||
connection
|
||||
.socket_mut()
|
||||
.bind(&addr)
|
||||
.map_err(Error::BindError)?;
|
||||
connection.socket_mut().bind(&addr).map_err(Error::Bind)?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
|
||||
@ -179,11 +176,11 @@ impl RouteManagerImpl {
|
||||
let mut req = NetlinkMessage::from(RtnlMessage::NewRule((*rule).clone()));
|
||||
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
|
||||
|
||||
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
|
||||
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
|
||||
|
||||
while let Some(message) = response.next().await {
|
||||
if let NetlinkPayload::Error(error) = message.payload {
|
||||
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
|
||||
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -236,7 +233,7 @@ impl RouteManagerImpl {
|
||||
let mut req = NetlinkMessage::from(RtnlMessage::GetRule(RuleMessage::default()));
|
||||
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
|
||||
|
||||
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
|
||||
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
|
||||
|
||||
let mut rules = vec![];
|
||||
|
||||
@ -246,7 +243,7 @@ impl RouteManagerImpl {
|
||||
rules.push(rule);
|
||||
}
|
||||
NetlinkPayload::Error(error) => {
|
||||
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
|
||||
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
@ -260,12 +257,12 @@ impl RouteManagerImpl {
|
||||
let mut req = NetlinkMessage::from(RtnlMessage::DelRule(rule));
|
||||
req.header.flags = NLM_F_REQUEST | NLM_F_ACK;
|
||||
|
||||
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
|
||||
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
|
||||
|
||||
while let Some(message) = response.next().await {
|
||||
if let NetlinkPayload::Error(error) = message.payload {
|
||||
if error.to_io().kind() != io::ErrorKind::NotFound {
|
||||
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
|
||||
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -296,7 +293,7 @@ impl RouteManagerImpl {
|
||||
) -> Result<BTreeMap<u32, NetworkInterface>> {
|
||||
let mut link_map = BTreeMap::new();
|
||||
let mut link_request = handle.link().get().execute();
|
||||
while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? {
|
||||
while let Some(link) = link_request.try_next().await.map_err(Error::Netlink)? {
|
||||
if let Some((idx, device)) = Self::map_interface(link) {
|
||||
link_map.insert(idx, device);
|
||||
}
|
||||
@ -543,7 +540,7 @@ impl RouteManagerImpl {
|
||||
|
||||
async fn delete_route_if_exists(&self, route: &Route) -> Result<()> {
|
||||
if let Err(error) = self.delete_route(route).await {
|
||||
if let Error::NetlinkError(rtnetlink::Error::NetlinkError(msg)) = &error {
|
||||
if let Error::Netlink(rtnetlink::Error::NetlinkError(msg)) = &error {
|
||||
if msg.code == -libc::ESRCH {
|
||||
return Ok(());
|
||||
}
|
||||
@ -619,7 +616,7 @@ impl RouteManagerImpl {
|
||||
.del(route_message)
|
||||
.execute()
|
||||
.await
|
||||
.map_err(Error::NetlinkError)
|
||||
.map_err(Error::Netlink)
|
||||
}
|
||||
|
||||
async fn add_route_direct(&mut self, route: Route) -> Result<()> {
|
||||
@ -693,11 +690,11 @@ impl RouteManagerImpl {
|
||||
let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message));
|
||||
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
|
||||
|
||||
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
|
||||
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
|
||||
|
||||
while let Some(message) = response.next().await {
|
||||
if let NetlinkPayload::Error(err) = message.payload {
|
||||
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err)));
|
||||
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(err)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
@ -759,7 +756,7 @@ impl RouteManagerImpl {
|
||||
}
|
||||
None => {
|
||||
log::error!("No route detected when assigning the mtu to the Wireguard tunnel");
|
||||
return Err(Error::NoRouteError);
|
||||
return Err(Error::NoRoute);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -767,17 +764,13 @@ impl RouteManagerImpl {
|
||||
"Retried {} times looking for the correct device and could not find it",
|
||||
RECURSION_LIMIT
|
||||
);
|
||||
Err(Error::NoRouteError)
|
||||
Err(Error::NoRoute)
|
||||
}
|
||||
|
||||
async fn get_device_mtu(&self, device: String) -> Result<u16> {
|
||||
let mut links = self.handle.link().get().execute();
|
||||
let target_device = LinkNla::IfName(device);
|
||||
while let Some(msg) = links
|
||||
.try_next()
|
||||
.await
|
||||
.map_err(|_| Error::LinkNotFoundError)?
|
||||
{
|
||||
while let Some(msg) = links.try_next().await.map_err(|_| Error::LinkNotFound)? {
|
||||
let found = msg.nlas.iter().any(|e| *e == target_device);
|
||||
if found {
|
||||
if let Some(LinkNla::Mtu(mtu)) =
|
||||
@ -788,7 +781,7 @@ impl RouteManagerImpl {
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(Error::LinkNotFoundError)
|
||||
Err(Error::LinkNotFound)
|
||||
}
|
||||
|
||||
async fn get_destination_route(
|
||||
@ -813,11 +806,11 @@ impl RouteManagerImpl {
|
||||
let mut stream = execute_route_get_request(self.handle.clone(), message.clone());
|
||||
match stream.try_next().await {
|
||||
Ok(Some(route_msg)) => self.parse_route_message(route_msg),
|
||||
Ok(None) => Err(Error::NoRouteError),
|
||||
Ok(None) => Err(Error::NoRoute),
|
||||
Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
|
||||
Ok(None)
|
||||
}
|
||||
Err(err) => Err(Error::GetRouteError(err)),
|
||||
Err(err) => Err(Error::GetRoute(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19,16 +19,19 @@ use futures::stream::Stream;
|
||||
#[cfg(target_os = "linux")]
|
||||
use std::net::IpAddr;
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
#[cfg(target_os = "macos")]
|
||||
#[path = "macos.rs"]
|
||||
mod imp;
|
||||
#[cfg(target_os = "macos")]
|
||||
pub(crate) use imp::listen_for_default_route_changes;
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
#[cfg(target_os = "linux")]
|
||||
#[path = "linux.rs"]
|
||||
mod imp;
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
#[cfg(target_os = "android")]
|
||||
#[path = "android.rs"]
|
||||
mod imp;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use self::tun_provider::TunProvider;
|
||||
use crate::{logging, routing::RouteManagerHandle};
|
||||
use futures::channel::oneshot;
|
||||
use futures::{channel::oneshot, future::BoxFuture};
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
||||
path::{Path, PathBuf},
|
||||
@ -98,6 +98,20 @@ pub struct TunnelMonitor {
|
||||
monitor: InternalTunnelMonitor,
|
||||
}
|
||||
|
||||
/// Arguments for creating a tunnel.
|
||||
pub struct TunnelArgs<'a, L>
|
||||
where
|
||||
// L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
|
||||
L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
|
||||
{
|
||||
/// Resource directory.
|
||||
pub resource_dir: &'a Path,
|
||||
/// Callback function called when an event happens.
|
||||
pub on_event: L,
|
||||
/// Receiver oneshot channel for closing the tunnel.
|
||||
pub tunnel_close_rx: oneshot::Receiver<()>,
|
||||
}
|
||||
|
||||
// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
|
||||
impl TunnelMonitor {
|
||||
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
|
||||
@ -107,12 +121,10 @@ impl TunnelMonitor {
|
||||
runtime: tokio::runtime::Handle,
|
||||
tunnel_parameters: &mut TunnelParameters,
|
||||
log_dir: &Option<PathBuf>,
|
||||
resource_dir: &Path,
|
||||
on_event: L,
|
||||
tun_provider: Arc<Mutex<TunProvider>>,
|
||||
route_manager: RouteManagerHandle,
|
||||
retry_attempt: u32,
|
||||
tunnel_close_rx: oneshot::Receiver<()>,
|
||||
route_manager: RouteManagerHandle,
|
||||
init_args: TunnelArgs<'_, L>,
|
||||
) -> Result<Self>
|
||||
where
|
||||
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
|
||||
@ -129,9 +141,9 @@ impl TunnelMonitor {
|
||||
TunnelParameters::OpenVpn(config) => runtime.block_on(Self::start_openvpn_tunnel(
|
||||
config,
|
||||
log_file,
|
||||
resource_dir,
|
||||
on_event,
|
||||
tunnel_close_rx,
|
||||
init_args.resource_dir,
|
||||
init_args.on_event,
|
||||
init_args.tunnel_close_rx,
|
||||
#[cfg(target_os = "linux")]
|
||||
route_manager,
|
||||
)),
|
||||
@ -142,12 +154,10 @@ impl TunnelMonitor {
|
||||
runtime,
|
||||
config,
|
||||
log_file,
|
||||
resource_dir,
|
||||
on_event,
|
||||
tun_provider,
|
||||
route_manager,
|
||||
retry_attempt,
|
||||
tunnel_close_rx,
|
||||
route_manager,
|
||||
init_args,
|
||||
),
|
||||
}
|
||||
}
|
||||
@ -178,12 +188,10 @@ impl TunnelMonitor {
|
||||
runtime: tokio::runtime::Handle,
|
||||
params: &mut wireguard_types::TunnelParameters,
|
||||
log: Option<PathBuf>,
|
||||
resource_dir: &Path,
|
||||
on_event: L,
|
||||
tun_provider: Arc<Mutex<TunProvider>>,
|
||||
route_manager: RouteManagerHandle,
|
||||
retry_attempt: u32,
|
||||
tunnel_close_rx: oneshot::Receiver<()>,
|
||||
route_manager: RouteManagerHandle,
|
||||
init_args: TunnelArgs<'_, L>,
|
||||
) -> Result<Self>
|
||||
where
|
||||
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
|
||||
@ -211,12 +219,10 @@ impl TunnelMonitor {
|
||||
None
|
||||
},
|
||||
log.as_deref(),
|
||||
resource_dir,
|
||||
on_event,
|
||||
tun_provider,
|
||||
route_manager,
|
||||
retry_attempt,
|
||||
tunnel_close_rx,
|
||||
route_manager,
|
||||
init_args,
|
||||
)?;
|
||||
Ok(TunnelMonitor {
|
||||
monitor: InternalTunnelMonitor::Wireguard(monitor),
|
||||
|
@ -310,10 +310,19 @@ impl OpenVpnMonitor<OpenVpnCommand> {
|
||||
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
|
||||
let openvpn_init_args = OpenVpnTunnelInitArgs {
|
||||
event_server_abort_tx: event_server_abort_tx.clone(),
|
||||
event_server_abort_rx,
|
||||
plugin_path,
|
||||
log_path,
|
||||
user_pass_file,
|
||||
proxy_auth_file,
|
||||
proxy_monitor,
|
||||
tunnel_close_rx,
|
||||
};
|
||||
Self::new_internal(
|
||||
cmd,
|
||||
event_server_abort_tx.clone(),
|
||||
event_server_abort_rx,
|
||||
openvpn_init_args,
|
||||
event_server::OpenvpnEventProxyImpl {
|
||||
on_event,
|
||||
user_pass_file_path: user_pass_file_path.clone(),
|
||||
@ -324,12 +333,6 @@ impl OpenVpnMonitor<OpenVpnCommand> {
|
||||
#[cfg(target_os = "linux")]
|
||||
ipv6_enabled,
|
||||
},
|
||||
plugin_path,
|
||||
log_path,
|
||||
user_pass_file,
|
||||
proxy_auth_file,
|
||||
proxy_monitor,
|
||||
tunnel_close_rx,
|
||||
#[cfg(windows)]
|
||||
Box::new(wintun),
|
||||
)
|
||||
@ -371,23 +374,36 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute
|
||||
Ok(routes)
|
||||
}
|
||||
|
||||
impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
|
||||
async fn new_internal<L>(
|
||||
mut cmd: C,
|
||||
struct OpenVpnTunnelInitArgs {
|
||||
event_server_abort_tx: triggered::Trigger,
|
||||
event_server_abort_rx: triggered::Listener,
|
||||
on_event: L,
|
||||
plugin_path: PathBuf,
|
||||
log_path: Option<PathBuf>,
|
||||
user_pass_file: mktemp::TempFile,
|
||||
proxy_auth_file: Option<mktemp::TempFile>,
|
||||
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
|
||||
tunnel_close_rx: oneshot::Receiver<()>,
|
||||
}
|
||||
|
||||
impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
|
||||
async fn new_internal<L>(
|
||||
mut cmd: C,
|
||||
init_args: OpenVpnTunnelInitArgs,
|
||||
on_event: L,
|
||||
#[cfg(windows)] wintun: Box<dyn WintunContext>,
|
||||
) -> Result<OpenVpnMonitor<C>>
|
||||
where
|
||||
L: event_server::OpenvpnEventProxy + Send + Sync + 'static,
|
||||
{
|
||||
let event_server_abort_tx = init_args.event_server_abort_tx;
|
||||
let event_server_abort_rx = init_args.event_server_abort_rx;
|
||||
let plugin_path = init_args.plugin_path;
|
||||
let log_path = init_args.log_path;
|
||||
let user_pass_file = init_args.user_pass_file;
|
||||
let proxy_auth_file = init_args.proxy_auth_file;
|
||||
let proxy_monitor = init_args.proxy_monitor;
|
||||
let tunnel_close_rx = init_args.tunnel_close_rx;
|
||||
|
||||
let (server_join_handle, ipc_path) = event_server::start(on_event, event_server_abort_rx)
|
||||
.await
|
||||
.map_err(Error::EventDispatcherError)?;
|
||||
@ -1220,23 +1236,37 @@ mod tests {
|
||||
.map_err(Error::RuntimeError)
|
||||
}
|
||||
|
||||
fn create_init_args_plugin_log(
|
||||
plugin_path: PathBuf,
|
||||
log_path: Option<PathBuf>,
|
||||
) -> OpenVpnTunnelInitArgs {
|
||||
let (_close_tx, close_rx) = oneshot::channel();
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
OpenVpnTunnelInitArgs {
|
||||
event_server_abort_tx,
|
||||
event_server_abort_rx,
|
||||
plugin_path,
|
||||
log_path,
|
||||
user_pass_file: TempFile::new(),
|
||||
proxy_auth_file: None,
|
||||
proxy_monitor: None,
|
||||
tunnel_close_rx: close_rx,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_init_args() -> OpenVpnTunnelInitArgs {
|
||||
create_init_args_plugin_log("".into(), None)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sets_plugin() {
|
||||
let builder = TestOpenVpnBuilder::default();
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
let (_close_tx, close_rx) = oneshot::channel();
|
||||
let runtime = new_runtime().unwrap();
|
||||
let openvpn_init_args = create_init_args_plugin_log("./my_test_plugin".into(), None);
|
||||
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
|
||||
builder.clone(),
|
||||
event_server_abort_tx,
|
||||
event_server_abort_rx,
|
||||
openvpn_init_args,
|
||||
TestOpenvpnEventProxy {},
|
||||
"./my_test_plugin".into(),
|
||||
None,
|
||||
TempFile::new(),
|
||||
None,
|
||||
None,
|
||||
close_rx,
|
||||
#[cfg(windows)]
|
||||
Box::new(TestWintunContext {}),
|
||||
));
|
||||
@ -1249,20 +1279,13 @@ mod tests {
|
||||
#[test]
|
||||
fn sets_log() {
|
||||
let builder = TestOpenVpnBuilder::default();
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
let (_close_tx, close_rx) = oneshot::channel();
|
||||
let runtime = new_runtime().unwrap();
|
||||
let openvpn_init_args =
|
||||
create_init_args_plugin_log("".into(), Some(PathBuf::from("./my_test_log_file")));
|
||||
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
|
||||
builder.clone(),
|
||||
event_server_abort_tx,
|
||||
event_server_abort_rx,
|
||||
openvpn_init_args,
|
||||
TestOpenvpnEventProxy {},
|
||||
"".into(),
|
||||
Some(PathBuf::from("./my_test_log_file")),
|
||||
TempFile::new(),
|
||||
None,
|
||||
None,
|
||||
close_rx,
|
||||
#[cfg(windows)]
|
||||
Box::new(TestWintunContext {}),
|
||||
));
|
||||
@ -1276,21 +1299,13 @@ mod tests {
|
||||
fn exit_successfully() {
|
||||
let mut builder = TestOpenVpnBuilder::default();
|
||||
builder.process_handle = Some(TestProcessHandle(0));
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
let (_close_tx, close_rx) = oneshot::channel();
|
||||
let runtime = new_runtime().unwrap();
|
||||
let openvpn_init_args = create_init_args();
|
||||
let testee = runtime
|
||||
.block_on(OpenVpnMonitor::new_internal(
|
||||
builder,
|
||||
event_server_abort_tx,
|
||||
event_server_abort_rx,
|
||||
openvpn_init_args,
|
||||
TestOpenvpnEventProxy {},
|
||||
"".into(),
|
||||
None,
|
||||
TempFile::new(),
|
||||
None,
|
||||
None,
|
||||
close_rx,
|
||||
#[cfg(windows)]
|
||||
Box::new(TestWintunContext {}),
|
||||
))
|
||||
@ -1302,21 +1317,13 @@ mod tests {
|
||||
fn exit_error() {
|
||||
let mut builder = TestOpenVpnBuilder::default();
|
||||
builder.process_handle = Some(TestProcessHandle(1));
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
let (_close_tx, close_rx) = oneshot::channel();
|
||||
let runtime = new_runtime().unwrap();
|
||||
let openvpn_init_args = create_init_args();
|
||||
let testee = runtime
|
||||
.block_on(OpenVpnMonitor::new_internal(
|
||||
builder,
|
||||
event_server_abort_tx,
|
||||
event_server_abort_rx,
|
||||
openvpn_init_args,
|
||||
TestOpenvpnEventProxy {},
|
||||
"".into(),
|
||||
None,
|
||||
TempFile::new(),
|
||||
None,
|
||||
None,
|
||||
close_rx,
|
||||
#[cfg(windows)]
|
||||
Box::new(TestWintunContext {}),
|
||||
))
|
||||
@ -1328,21 +1335,13 @@ mod tests {
|
||||
fn wait_closed() {
|
||||
let mut builder = TestOpenVpnBuilder::default();
|
||||
builder.process_handle = Some(TestProcessHandle(1));
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
let (_close_tx, close_rx) = oneshot::channel();
|
||||
let runtime = new_runtime().unwrap();
|
||||
let openvpn_init_args = create_init_args();
|
||||
let testee = runtime
|
||||
.block_on(OpenVpnMonitor::new_internal(
|
||||
builder,
|
||||
event_server_abort_tx,
|
||||
event_server_abort_rx,
|
||||
openvpn_init_args,
|
||||
TestOpenvpnEventProxy {},
|
||||
"".into(),
|
||||
None,
|
||||
TempFile::new(),
|
||||
None,
|
||||
None,
|
||||
close_rx,
|
||||
#[cfg(windows)]
|
||||
Box::new(TestWintunContext {}),
|
||||
))
|
||||
@ -1354,21 +1353,13 @@ mod tests {
|
||||
#[test]
|
||||
fn failed_process_start() {
|
||||
let builder = TestOpenVpnBuilder::default();
|
||||
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
|
||||
let (_close_tx, close_rx) = oneshot::channel();
|
||||
let runtime = new_runtime().unwrap();
|
||||
let openvpn_init_args = create_init_args();
|
||||
let result = runtime
|
||||
.block_on(OpenVpnMonitor::new_internal(
|
||||
builder,
|
||||
event_server_abort_tx,
|
||||
event_server_abort_rx,
|
||||
openvpn_init_args,
|
||||
TestOpenvpnEventProxy {},
|
||||
"".into(),
|
||||
None,
|
||||
TempFile::new(),
|
||||
None,
|
||||
None,
|
||||
close_rx,
|
||||
#[cfg(windows)]
|
||||
Box::new(TestWintunContext {}),
|
||||
))
|
||||
|
@ -22,6 +22,12 @@ pub enum Error {
|
||||
/// Factory of tunnel devices on Unix systems.
|
||||
pub struct UnixTunProvider;
|
||||
|
||||
impl Default for UnixTunProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl UnixTunProvider {
|
||||
pub fn new() -> Self {
|
||||
UnixTunProvider
|
||||
|
@ -112,7 +112,7 @@ pub unsafe extern "system" fn wg_go_logging_callback(
|
||||
|
||||
let level = match level {
|
||||
WG_GO_LOG_VERBOSE => LogLevel::Verbose,
|
||||
WG_GO_LOG_ERROR | _ => LogLevel::Error,
|
||||
_ => LogLevel::Error,
|
||||
};
|
||||
log_inner(logfile, level, "wireguard-go", &managed_msg);
|
||||
}
|
||||
@ -121,5 +121,5 @@ pub unsafe extern "system" fn wg_go_logging_callback(
|
||||
pub type WgLogLevel = u32;
|
||||
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
|
||||
// const WG_GO_LOG_SILENT: WgLogLevel = 0;
|
||||
const WG_GO_LOG_ERROR: WgLogLevel = 1;
|
||||
// const WG_GO_LOG_ERROR: WgLogLevel = 1;
|
||||
const WG_GO_LOG_VERBOSE: WgLogLevel = 2;
|
||||
|
@ -1,15 +1,11 @@
|
||||
use self::config::Config;
|
||||
#[cfg(not(windows))]
|
||||
use super::tun_provider;
|
||||
use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata};
|
||||
use super::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
|
||||
use crate::routing::{self, RequiredRoute, RouteManagerHandle};
|
||||
use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
|
||||
#[cfg(windows)]
|
||||
use futures::{channel::mpsc, StreamExt};
|
||||
use futures::{
|
||||
channel::oneshot,
|
||||
future::{abortable, AbortHandle as FutureAbortHandle},
|
||||
Future,
|
||||
};
|
||||
#[cfg(target_os = "linux")]
|
||||
use lazy_static::lazy_static;
|
||||
#[cfg(target_os = "linux")]
|
||||
@ -54,6 +50,7 @@ mod wireguard_nt;
|
||||
use self::wireguard_go::WgGoTunnel;
|
||||
|
||||
type Result<T> = std::result::Result<T, Error>;
|
||||
type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>;
|
||||
|
||||
/// Errors that can happen in the Wireguard tunnel monitor.
|
||||
#[derive(err_derive::Error, Debug)]
|
||||
@ -104,12 +101,7 @@ pub struct WireguardMonitor {
|
||||
/// Tunnel implementation
|
||||
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
|
||||
/// Callback to signal tunnel events
|
||||
event_callback: Box<
|
||||
dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static,
|
||||
>,
|
||||
event_callback: EventCallback,
|
||||
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
|
||||
pinger_stop_sender: sync_mpsc::Sender<()>,
|
||||
_obfuscator: Option<ObfuscatorHandle>,
|
||||
@ -208,13 +200,13 @@ impl WireguardMonitor {
|
||||
mut config: Config,
|
||||
psk_negotiation: Option<PublicKey>,
|
||||
log_path: Option<&Path>,
|
||||
resource_dir: &Path,
|
||||
on_event: F,
|
||||
tun_provider: Arc<Mutex<TunProvider>>,
|
||||
route_manager: RouteManagerHandle,
|
||||
retry_attempt: u32,
|
||||
tunnel_close_rx: oneshot::Receiver<()>,
|
||||
route_manager: RouteManagerHandle,
|
||||
init_args: TunnelArgs<'_, F>,
|
||||
) -> Result<WireguardMonitor> {
|
||||
let on_event = init_args.on_event;
|
||||
|
||||
let endpoint_addrs: Vec<IpAddr> =
|
||||
config.peers.iter().map(|peer| peer.endpoint.ip()).collect();
|
||||
let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel();
|
||||
@ -228,7 +220,7 @@ impl WireguardMonitor {
|
||||
runtime.clone(),
|
||||
&Self::patch_allowed_ips(&config, psk_negotiation.is_some()),
|
||||
log_path,
|
||||
resource_dir,
|
||||
init_args.resource_dir,
|
||||
tun_provider,
|
||||
#[cfg(target_os = "windows")]
|
||||
setup_done_tx,
|
||||
@ -351,7 +343,7 @@ impl WireguardMonitor {
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
if tunnel_close_rx.await.is_ok() {
|
||||
if init_args.tunnel_close_rx.await.is_ok() {
|
||||
monitor_handle.abort();
|
||||
let _ = close_msg_sender.send(CloseMsg::Stop);
|
||||
}
|
||||
|
@ -4,10 +4,10 @@ use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla};
|
||||
#[derive(err_derive::Error, Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
#[error(display = "Failed to parse peer pubkey from string \"_0\"")]
|
||||
PubKeyParseError(String, #[error(source)] hex::FromHexError),
|
||||
PubKeyParse(String, #[error(source)] hex::FromHexError),
|
||||
|
||||
#[error(display = "Failed to parse integer from string \"_0\"")]
|
||||
IntParseError(String, #[error(source)] std::num::ParseIntError),
|
||||
IntParse(String, #[error(source)] std::num::ParseIntError),
|
||||
|
||||
#[error(display = "Device no longer exists")]
|
||||
NoTunnelDevice,
|
||||
@ -47,7 +47,7 @@ impl Stats {
|
||||
"public_key" => {
|
||||
let mut buffer = [0u8; 32];
|
||||
hex::decode_to_slice(value, &mut buffer)
|
||||
.map_err(|err| Error::PubKeyParseError(value.to_string(), err))?;
|
||||
.map_err(|err| Error::PubKeyParse(value.to_string(), err))?;
|
||||
peer = Some(buffer);
|
||||
tx_bytes = None;
|
||||
rx_bytes = None;
|
||||
@ -57,7 +57,7 @@ impl Stats {
|
||||
value
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(|err| Error::IntParseError(value.to_string(), err))?,
|
||||
.map_err(|err| Error::IntParse(value.to_string(), err))?,
|
||||
);
|
||||
}
|
||||
"tx_bytes" => {
|
||||
@ -65,7 +65,7 @@ impl Stats {
|
||||
value
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(|err| Error::IntParseError(value.to_string(), err))?,
|
||||
.map_err(|err| Error::IntParse(value.to_string(), err))?,
|
||||
);
|
||||
}
|
||||
|
||||
@ -145,7 +145,7 @@ mod test {
|
||||
|
||||
assert_eq!(
|
||||
Stats::parse_config_str(invalid_input),
|
||||
Err(Error::IntParseError(invalid_str, int_err))
|
||||
Err(Error::IntParse(invalid_str, int_err))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -33,16 +33,16 @@ pub use nm_tunnel::NetworkManagerTunnel;
|
||||
#[error(no_from)]
|
||||
pub enum Error {
|
||||
#[error(display = "Failed to decode netlink message")]
|
||||
DecodeError(#[error(source)] DecodeError),
|
||||
Decode(#[error(source)] DecodeError),
|
||||
|
||||
#[error(display = "Failed to execute netlink control request")]
|
||||
NetlinkControlMessageError(#[error(source)] nl_message::Error),
|
||||
NetlinkControlMessage(#[error(source)] nl_message::Error),
|
||||
|
||||
#[error(display = "Failed to open netlink socket")]
|
||||
NetlinkSocketError(#[error(source)] std::io::Error),
|
||||
NetlinkSocket(#[error(source)] std::io::Error),
|
||||
|
||||
#[error(display = "Failed to send netlink control request")]
|
||||
NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
|
||||
NetlinkRequest(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
|
||||
|
||||
#[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")]
|
||||
WireguardNetlinkInterfaceUnavailable,
|
||||
@ -60,25 +60,25 @@ pub enum Error {
|
||||
NoDevice,
|
||||
|
||||
#[error(display = "Failed to get config: _0")]
|
||||
WgGetConfError(netlink_packet_core::error::ErrorMessage),
|
||||
WgGetConf(netlink_packet_core::error::ErrorMessage),
|
||||
|
||||
#[error(display = "Failed to apply config: _0")]
|
||||
WgSetConfError(netlink_packet_core::error::ErrorMessage),
|
||||
WgSetConf(netlink_packet_core::error::ErrorMessage),
|
||||
|
||||
#[error(display = "Interface name too long")]
|
||||
InterfaceNameError,
|
||||
InterfaceName,
|
||||
|
||||
#[error(display = "Send request error")]
|
||||
SendRequestError(#[error(source)] NetlinkError<DeviceMessage>),
|
||||
SendRequest(#[error(source)] NetlinkError<DeviceMessage>),
|
||||
|
||||
#[error(display = "Create device error")]
|
||||
NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error),
|
||||
NetlinkCreateDevice(#[error(source)] rtnetlink::Error),
|
||||
|
||||
#[error(display = "Add IP to device error")]
|
||||
NetlinkSetIpError(rtnetlink::Error),
|
||||
NetlinkSetIp(rtnetlink::Error),
|
||||
|
||||
#[error(display = "Failed to delete device")]
|
||||
DeleteDeviceError(#[error(source)] rtnetlink::Error),
|
||||
DeleteDevice(#[error(source)] rtnetlink::Error),
|
||||
|
||||
#[error(display = "NetworkManager error")]
|
||||
NetworkManager(#[error(source)] nm_tunnel::Error),
|
||||
@ -98,7 +98,7 @@ impl Handle {
|
||||
pub async fn connect() -> Result<Self, Error> {
|
||||
let message_type = Self::get_wireguard_message_type().await?;
|
||||
let (conn, wireguard_connection, _messages) =
|
||||
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
|
||||
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
|
||||
let wg_handle = WireguardConnection {
|
||||
message_type,
|
||||
connection: wireguard_connection,
|
||||
@ -106,7 +106,7 @@ impl Handle {
|
||||
let (abortable_connection, wg_abort_handle) = abortable(conn);
|
||||
tokio::spawn(abortable_connection);
|
||||
let (conn, route_handle, _messages) =
|
||||
rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?;
|
||||
rtnetlink::new_connection().map_err(Error::NetlinkSocket)?;
|
||||
let (abortable_connection, route_abort_handle) = abortable(conn);
|
||||
tokio::spawn(abortable_connection);
|
||||
|
||||
@ -120,21 +120,21 @@ impl Handle {
|
||||
|
||||
async fn get_wireguard_message_type() -> Result<u16, Error> {
|
||||
let (conn, mut handle, _messages) =
|
||||
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
|
||||
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
|
||||
let (conn, abort_handle) = abortable(conn);
|
||||
tokio::spawn(conn);
|
||||
|
||||
let result = async move {
|
||||
let mut message: NetlinkMessage<NetlinkControlMessage> =
|
||||
NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap())
|
||||
.map_err(Error::NetlinkControlMessageError)?
|
||||
.map_err(Error::NetlinkControlMessage)?
|
||||
.into();
|
||||
|
||||
message.header.flags = NLM_F_REQUEST | NLM_F_ACK;
|
||||
|
||||
let mut req = handle
|
||||
.request(message, SocketAddr::new(0, 0))
|
||||
.map_err(Error::NetlinkRequestError)?;
|
||||
.map_err(Error::NetlinkRequest)?;
|
||||
let response = req.next().await;
|
||||
if let Some(response) = response {
|
||||
if let NetlinkPayload::InnerMessage(msg) = response.payload {
|
||||
@ -177,14 +177,14 @@ impl Handle {
|
||||
let mut response = self
|
||||
.route_handle
|
||||
.request(add_request)
|
||||
.map_err(Error::NetlinkCreateDeviceError)?;
|
||||
.map_err(Error::NetlinkCreateDevice)?;
|
||||
while let Some(response_message) = response.next().await {
|
||||
if let NetlinkPayload::Error(err) = response_message.payload {
|
||||
// if the device exists, verify that it's a wireguard device
|
||||
if -err.code != libc::EEXIST {
|
||||
return Err(Error::NetlinkCreateDeviceError(
|
||||
rtnetlink::Error::NetlinkError(err),
|
||||
));
|
||||
return Err(Error::NetlinkCreateDevice(rtnetlink::Error::NetlinkError(
|
||||
err,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -208,9 +208,9 @@ impl Handle {
|
||||
let mut response = self
|
||||
.route_handle
|
||||
.request(request)
|
||||
.map_err(Error::NetlinkSetIpError)?;
|
||||
.map_err(Error::NetlinkSetIp)?;
|
||||
while let Some(response_message) = response.next().await {
|
||||
consume_netlink_error(response_message, Error::NetlinkSetIpError)?;
|
||||
consume_netlink_error(response_message, Error::NetlinkSetIp)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -226,9 +226,9 @@ impl Handle {
|
||||
let mut response = self
|
||||
.route_handle
|
||||
.request(request)
|
||||
.map_err(Error::DeleteDeviceError)?;
|
||||
.map_err(Error::DeleteDevice)?;
|
||||
while let Some(message) = response.next().await {
|
||||
consume_netlink_error(message, Error::DeleteDeviceError)?;
|
||||
consume_netlink_error(message, Error::DeleteDevice)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -269,7 +269,7 @@ impl WireguardConnection {
|
||||
let mut response = self
|
||||
.connection
|
||||
.request(netlink_message, SocketAddr::new(0, 0))
|
||||
.map_err(Error::SendRequestError)?;
|
||||
.map_err(Error::SendRequest)?;
|
||||
match response.next().await {
|
||||
Some(received_message) => match received_message.payload {
|
||||
NetlinkPayload::InnerMessage(inner) => Ok(inner),
|
||||
@ -277,7 +277,7 @@ impl WireguardConnection {
|
||||
if err.code == -libc::ENODEV {
|
||||
Err(Error::NoDevice)
|
||||
} else {
|
||||
Err(Error::WgGetConfError(err))
|
||||
Err(Error::WgGetConf(err))
|
||||
}
|
||||
}
|
||||
anything_else => {
|
||||
@ -297,11 +297,11 @@ impl WireguardConnection {
|
||||
let mut request = self
|
||||
.connection
|
||||
.request(netlink_message, SocketAddr::new(0, 0))
|
||||
.map_err(Error::SendRequestError)?;
|
||||
.map_err(Error::SendRequest)?;
|
||||
|
||||
while let Some(response) = request.next().await {
|
||||
if let NetlinkPayload::Error(err) = response.payload {
|
||||
return Err(Error::WgSetConfError(err));
|
||||
return Err(Error::WgSetConf(err));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -110,9 +110,9 @@ impl DeviceMessage {
|
||||
}
|
||||
|
||||
pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> {
|
||||
let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?;
|
||||
let c_name = CString::new(name).map_err(|_| Error::InterfaceName)?;
|
||||
if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ {
|
||||
return Err(Error::InterfaceNameError);
|
||||
return Err(Error::InterfaceName);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
@ -178,9 +178,7 @@ impl NetlinkDeserializable<DeviceMessage> for DeviceMessage {
|
||||
let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..];
|
||||
let mut nlas = vec![];
|
||||
for buf in NlasIterator::new(new_payload) {
|
||||
nlas.push(
|
||||
DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?,
|
||||
);
|
||||
nlas.push(DeviceNla::parse(&buf.map_err(Error::Decode)?).map_err(Error::Decode)?);
|
||||
}
|
||||
|
||||
Ok(DeviceMessage {
|
||||
@ -391,13 +389,13 @@ impl Nla for PeerNla {
|
||||
InetAddr::V4(sockaddr_in) => {
|
||||
// SAFETY: `sockaddr_in` has no padding bytes
|
||||
buffer
|
||||
.write(unsafe { struct_as_slice(sockaddr_in) })
|
||||
.write_all(unsafe { struct_as_slice(sockaddr_in) })
|
||||
.expect("Buffer too small for sockaddr_in");
|
||||
}
|
||||
InetAddr::V6(sockaddr_in6) => {
|
||||
// SAFETY: `sockaddr_in` has no padding bytes
|
||||
buffer
|
||||
.write(unsafe { struct_as_slice(sockaddr_in6) })
|
||||
.write_all(unsafe { struct_as_slice(sockaddr_in6) })
|
||||
.expect("Buffer too small for sockaddr_in6");
|
||||
}
|
||||
},
|
||||
@ -408,7 +406,7 @@ impl Nla for PeerNla {
|
||||
let timespec: &libc::timespec = last_handshake.as_ref();
|
||||
// SAFETY: `timespec` has no padding bytes
|
||||
buffer
|
||||
.write(unsafe { struct_as_slice(timespec) })
|
||||
.write_all(unsafe { struct_as_slice(timespec) })
|
||||
.expect("Buffer too small for timespec");
|
||||
}
|
||||
RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes),
|
||||
@ -535,7 +533,7 @@ impl Nla for AllowedIpNla {
|
||||
}
|
||||
IpAddr(ip_addr) => {
|
||||
buffer
|
||||
.write(&ip_addr_to_bytes(ip_addr))
|
||||
.write_all(&ip_addr_to_bytes(ip_addr))
|
||||
.expect("Buffer too small for AllowedIpNla::IpAddr");
|
||||
}
|
||||
CidrMask(cidr_mask) => buffer[0] = *cidr_mask,
|
||||
|
@ -6,7 +6,9 @@ use super::{
|
||||
use crate::{
|
||||
firewall::FirewallPolicy,
|
||||
routing::RouteManager,
|
||||
tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor},
|
||||
tunnel::{
|
||||
self, tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata, TunnelMonitor,
|
||||
},
|
||||
};
|
||||
use cfg_if::cfg_if;
|
||||
use futures::{
|
||||
@ -142,16 +144,20 @@ impl ConnectingState {
|
||||
}
|
||||
};
|
||||
|
||||
let init_args = TunnelArgs {
|
||||
resource_dir: &resource_dir,
|
||||
on_event: on_tunnel_event,
|
||||
tunnel_close_rx,
|
||||
};
|
||||
|
||||
let block_reason = match TunnelMonitor::start(
|
||||
runtime,
|
||||
&mut tunnel_parameters,
|
||||
&log_dir,
|
||||
&resource_dir,
|
||||
on_tunnel_event,
|
||||
tun_provider,
|
||||
route_manager_handle,
|
||||
retry_attempt,
|
||||
tunnel_close_rx,
|
||||
route_manager_handle,
|
||||
init_args,
|
||||
) {
|
||||
Ok(monitor) => {
|
||||
let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt);
|
||||
|
@ -132,23 +132,25 @@ pub async fn spawn(
|
||||
let (shutdown_tx, shutdown_rx) = oneshot::channel();
|
||||
|
||||
let weak_command_tx = Arc::downgrade(&command_tx);
|
||||
let state_machine = TunnelStateMachine::new(
|
||||
initial_settings,
|
||||
weak_command_tx,
|
||||
offline_state_listener,
|
||||
|
||||
let init_args = TunnelStateMachineInitArgs {
|
||||
settings: initial_settings,
|
||||
command_tx: weak_command_tx,
|
||||
offline_state_tx: offline_state_listener,
|
||||
tunnel_parameters_generator,
|
||||
tun_provider,
|
||||
log_dir,
|
||||
resource_dir,
|
||||
command_rx,
|
||||
commands_rx: command_rx,
|
||||
#[cfg(target_os = "windows")]
|
||||
volume_update_rx,
|
||||
#[cfg(target_os = "macos")]
|
||||
exclusion_gid,
|
||||
#[cfg(target_os = "android")]
|
||||
android_context,
|
||||
)
|
||||
.await?;
|
||||
};
|
||||
|
||||
let state_machine = TunnelStateMachine::new(init_args).await?;
|
||||
|
||||
#[cfg(windows)]
|
||||
let split_tunnel = state_machine.shared_values.split_tunnel.handle();
|
||||
@ -219,20 +221,35 @@ struct TunnelStateMachine {
|
||||
shared_values: SharedTunnelStateValues,
|
||||
}
|
||||
|
||||
impl TunnelStateMachine {
|
||||
async fn new(
|
||||
/// Tunnel state machine initialization arguments arguments
|
||||
struct TunnelStateMachineInitArgs<G: TunnelParametersGenerator> {
|
||||
settings: InitialTunnelState,
|
||||
command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
|
||||
offline_state_tx: mpsc::UnboundedSender<bool>,
|
||||
tunnel_parameters_generator: impl TunnelParametersGenerator,
|
||||
tunnel_parameters_generator: G,
|
||||
tun_provider: TunProvider,
|
||||
log_dir: Option<PathBuf>,
|
||||
resource_dir: PathBuf,
|
||||
commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
|
||||
#[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
|
||||
#[cfg(target_os = "macos")] exclusion_gid: u32,
|
||||
#[cfg(target_os = "android")] android_context: AndroidContext,
|
||||
#[cfg(target_os = "windows")]
|
||||
volume_update_rx: mpsc::UnboundedReceiver<()>,
|
||||
#[cfg(target_os = "macos")]
|
||||
exclusion_gid: u32,
|
||||
#[cfg(target_os = "android")]
|
||||
android_context: AndroidContext,
|
||||
}
|
||||
|
||||
impl TunnelStateMachine {
|
||||
async fn new(
|
||||
args: TunnelStateMachineInitArgs<impl TunnelParametersGenerator>,
|
||||
) -> Result<Self, Error> {
|
||||
#[cfg(target_os = "windows")]
|
||||
let volume_update_rx = args.volume_update_rx;
|
||||
#[cfg(target_os = "macos")]
|
||||
let exclusion_gid = args.exclusion_gid;
|
||||
#[cfg(target_os = "android")]
|
||||
let android_context = args.android_context;
|
||||
|
||||
let runtime = tokio::runtime::Handle::current();
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
@ -242,20 +259,24 @@ impl TunnelStateMachine {
|
||||
let power_mgmt_rx = crate::windows::window::PowerManagementListener::new();
|
||||
|
||||
#[cfg(windows)]
|
||||
let split_tunnel =
|
||||
split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx)
|
||||
let split_tunnel = split_tunnel::SplitTunnel::new(
|
||||
runtime.clone(),
|
||||
args.command_tx.clone(),
|
||||
volume_update_rx,
|
||||
)
|
||||
.map_err(Error::InitSplitTunneling)?;
|
||||
|
||||
let args = FirewallArguments {
|
||||
initial_state: if settings.block_when_disconnected || !settings.reset_firewall {
|
||||
InitialFirewallState::Blocked(settings.allowed_endpoint.clone())
|
||||
let fw_args = FirewallArguments {
|
||||
initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall
|
||||
{
|
||||
InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone())
|
||||
} else {
|
||||
InitialFirewallState::None
|
||||
},
|
||||
allow_lan: settings.allow_lan,
|
||||
allow_lan: args.settings.allow_lan,
|
||||
};
|
||||
|
||||
let firewall = Firewall::from_args(args).map_err(Error::InitFirewallError)?;
|
||||
let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?;
|
||||
let route_manager = RouteManager::new(HashSet::new())
|
||||
.await
|
||||
.map_err(Error::InitRouteManagerError)?;
|
||||
@ -267,20 +288,20 @@ impl TunnelStateMachine {
|
||||
.handle()
|
||||
.map_err(Error::InitRouteManagerError)?,
|
||||
#[cfg(target_os = "macos")]
|
||||
command_tx.clone(),
|
||||
args.command_tx.clone(),
|
||||
)
|
||||
.map_err(Error::InitDnsMonitorError)?;
|
||||
|
||||
let (offline_tx, mut offline_rx) = mpsc::unbounded();
|
||||
let initial_offline_state_tx = offline_state_tx.clone();
|
||||
let initial_offline_state_tx = args.offline_state_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(offline) = offline_rx.next().await {
|
||||
if let Some(tx) = command_tx.upgrade() {
|
||||
if let Some(tx) = args.command_tx.upgrade() {
|
||||
let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
let _ = offline_state_tx.unbounded_send(offline);
|
||||
let _ = args.offline_state_tx.unbounded_send(offline);
|
||||
}
|
||||
});
|
||||
let mut offline_monitor = offline::spawn_monitor(
|
||||
@ -301,7 +322,7 @@ impl TunnelStateMachine {
|
||||
|
||||
#[cfg(windows)]
|
||||
split_tunnel
|
||||
.set_paths_sync(&settings.exclude_paths)
|
||||
.set_paths_sync(&args.settings.exclude_paths)
|
||||
.map_err(Error::InitSplitTunneling)?;
|
||||
|
||||
let mut shared_values = SharedTunnelStateValues {
|
||||
@ -312,15 +333,15 @@ impl TunnelStateMachine {
|
||||
dns_monitor,
|
||||
route_manager,
|
||||
_offline_monitor: offline_monitor,
|
||||
allow_lan: settings.allow_lan,
|
||||
block_when_disconnected: settings.block_when_disconnected,
|
||||
allow_lan: args.settings.allow_lan,
|
||||
block_when_disconnected: args.settings.block_when_disconnected,
|
||||
is_offline,
|
||||
dns_servers: settings.dns_servers,
|
||||
allowed_endpoint: settings.allowed_endpoint,
|
||||
tunnel_parameters_generator: Box::new(tunnel_parameters_generator),
|
||||
tun_provider: Arc::new(Mutex::new(tun_provider)),
|
||||
log_dir,
|
||||
resource_dir,
|
||||
dns_servers: args.settings.dns_servers,
|
||||
allowed_endpoint: args.settings.allowed_endpoint,
|
||||
tunnel_parameters_generator: Box::new(args.tunnel_parameters_generator),
|
||||
tun_provider: Arc::new(Mutex::new(args.tun_provider)),
|
||||
log_dir: args.log_dir,
|
||||
resource_dir: args.resource_dir,
|
||||
#[cfg(target_os = "linux")]
|
||||
connectivity_check_was_enabled: None,
|
||||
#[cfg(target_os = "macos")]
|
||||
@ -331,11 +352,11 @@ impl TunnelStateMachine {
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let (initial_state, _) =
|
||||
DisconnectedState::enter(&mut shared_values, settings.reset_firewall);
|
||||
DisconnectedState::enter(&mut shared_values, args.settings.reset_firewall);
|
||||
|
||||
Ok(TunnelStateMachine {
|
||||
current_state: Some(initial_state),
|
||||
commands: commands_rx.fuse(),
|
||||
commands: args.commands_rx.fuse(),
|
||||
shared_values,
|
||||
})
|
||||
})
|
||||
|
@ -59,6 +59,7 @@ const MAXIMUM_SUPPORTED_MINOR_VERSION: u32 = 26;
|
||||
const NM_DEVICE_STATE_CHANGED: &str = "StateChanged";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
type NetworkSettings<'a> = HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>;
|
||||
|
||||
#[derive(err_derive::Error, Debug)]
|
||||
pub enum Error {
|
||||
@ -447,10 +448,8 @@ impl NetworkManager {
|
||||
|
||||
let device = self.as_path(&device_path);
|
||||
// Get the last applied connection
|
||||
let (mut settings, version_id): (
|
||||
HashMap<String, HashMap<String, Variant<Box<dyn RefArg>>>>,
|
||||
u64,
|
||||
) = device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?;
|
||||
let (mut settings, version_id): (NetworkSettings, u64) =
|
||||
device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?;
|
||||
|
||||
// Keep changed routes.
|
||||
// These routes were modified outside NM, likely by RouteManager.
|
||||
@ -576,7 +575,7 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
fn update_dns_config<'a, T>(
|
||||
settings: &mut HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>,
|
||||
settings: &mut NetworkSettings<'a>,
|
||||
ip_protocol: &'static str,
|
||||
servers: T,
|
||||
) where
|
||||
|
@ -349,7 +349,7 @@ impl SystemdResolved {
|
||||
.map_err(Error::DBusRpcError)
|
||||
}
|
||||
|
||||
fn link_disable_dns_over_tls<'a, 'b: 'a>(&'a self, interface_index: u32) -> Result<()> {
|
||||
fn link_disable_dns_over_tls(&self, interface_index: u32) -> Result<()> {
|
||||
let link_object_path = self
|
||||
.fetch_link(interface_index)
|
||||
.map_err(|e| Error::GetLinkError(Box::new(e)))?;
|
||||
|
@ -95,6 +95,7 @@ fn default_wgnt_setting() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[allow(clippy::derivable_impls)]
|
||||
impl Default for TunnelOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
Loading…
x
Reference in New Issue
Block a user