Fix the large majority of clippy warnings

This commit fixes most of the remaining clippy warnings in the codebase.
These warnings were the more semantically difficult ones to fix.
There are some warnings that remain from the rebase that will be fixed
in the upcoming PR.
This commit is contained in:
Jonathan 2022-06-13 10:49:46 +02:00
parent b6b80b9ffe
commit d3da8745c8
46 changed files with 450 additions and 423 deletions

View File

@ -84,13 +84,6 @@ impl StringValue {
}
}
impl StringValue {
/// Clones the internal string value.
pub fn to_string(&self) -> String {
self.0.clone()
}
}
impl Deref for StringValue {
type Target = str;

1
clippy.toml Normal file
View File

@ -0,0 +1 @@
enum-variant-size-threshold = 1000

View File

@ -10,19 +10,16 @@ use tokio::{
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to open the address cache file")]
OpenAddressCache(#[error(source)] io::Error),
Open(#[error(source)] io::Error),
#[error(display = "Failed to read the address cache file")]
ReadAddressCache(#[error(source)] io::Error),
Read(#[error(source)] io::Error),
#[error(display = "Failed to parse the address cache file")]
ParseAddressCache,
Parse,
#[error(display = "Failed to update the address cache file")]
WriteAddressCache(#[error(source)] io::Error),
#[error(display = "The address cache is empty")]
EmptyAddressCache,
Write(#[error(source)] io::Error),
}
#[derive(Clone)]
@ -68,7 +65,7 @@ impl AddressCache {
self.inner.lock().await.address
}
pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> {
pub async fn set_address(&self, address: SocketAddr) -> Result<(), Error> {
let mut inner = self.inner.lock().await;
if address != inner.address {
self.save_to_disk(&address).await?;
@ -77,17 +74,21 @@ impl AddressCache {
Ok(())
}
async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> {
async fn save_to_disk(&self, address: &SocketAddr) -> Result<(), Error> {
let write_path = match self.write_path.as_ref() {
Some(write_path) => write_path,
None => return Ok(()),
};
let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf()).await?;
let mut file = crate::fs::AtomicFile::new(write_path.to_path_buf())
.await
.map_err(Error::Open)?;
let mut contents = address.to_string();
contents += "\n";
file.write_all(contents.as_bytes()).await?;
file.finalize().await
file.write_all(contents.as_bytes())
.await
.map_err(Error::Write)?;
file.finalize().await.map_err(Error::Write)
}
}
@ -103,12 +104,10 @@ impl AddressCacheInner {
}
async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> {
let mut file = fs::File::open(path)
.await
.map_err(Error::OpenAddressCache)?;
let mut file = fs::File::open(path).await.map_err(Error::Open)?;
let mut address = String::new();
file.read_to_string(&mut address)
.await
.map_err(Error::ReadAddressCache)?;
address.trim().parse().map_err(|_| Error::ParseAddressCache)
.map_err(Error::Read)?;
address.trim().parse().map_err(|_| Error::Parse)
}

View File

@ -304,7 +304,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
)
.await?;
let tls_stream = TlsStream::connect_https(socket, &hostname).await?;
Ok::<_, io::Error>(ApiConnection::Direct(tls_stream))
Ok::<_, io::Error>(ApiConnection::Direct(Box::new(tls_stream)))
}
InnerConnectionMode::Proxied(proxy_config) => {
let socket = Self::open_socket(
@ -320,7 +320,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
addr,
);
let tls_stream = TlsStream::connect_https(proxy, &hostname).await?;
Ok(ApiConnection::Proxied(tls_stream))
Ok(ApiConnection::Proxied(Box::new(tls_stream)))
}
}
};

View File

@ -236,7 +236,7 @@ impl Runtime {
new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
let service_handle = rest::RequestService::new(
let service_handle = rest::RequestService::spawn(
sni_hostname,
self.api_availability.handle(),
self.address_cache.clone(),

View File

@ -132,8 +132,8 @@ impl ApiConnectionMode {
/// Stream that is either a regular TLS stream or TLS via shadowsocks
pub enum ApiConnection {
Direct(TlsStream<TcpStream>),
Proxied(TlsStream<ProxyClientStream<TcpStream>>),
Direct(Box<TlsStream<TcpStream>>),
Proxied(Box<TlsStream<ProxyClientStream<TcpStream>>>),
}
impl AsyncRead for ApiConnection {

View File

@ -130,7 +130,7 @@ impl ServerRelayList {
) {
let openvpn_endpoint_data = openvpn.ports;
for mut openvpn_relay in openvpn.relays.into_iter() {
openvpn_relay.to_lower();
openvpn_relay.convert_to_lowercase();
if let Some((country_code, city_code)) = split_location_code(&openvpn_relay.location) {
if let Some(country) = countries.get_mut(country_code) {
if let Some(city) = country
@ -184,7 +184,7 @@ impl ServerRelayList {
};
for mut wireguard_relay in relays {
wireguard_relay.relay.to_lower();
wireguard_relay.relay.convert_to_lowercase();
if let Some((country_code, city_code)) =
split_location_code(&wireguard_relay.relay.location)
{
@ -235,7 +235,7 @@ impl ServerRelayList {
} = bridges;
for mut bridge_relay in relays {
bridge_relay.to_lower();
bridge_relay.convert_to_lowercase();
if let Some((country_code, city_code)) = split_location_code(&bridge_relay.location) {
if let Some(country) = countries.get_mut(country_code) {
if let Some(city) = country
@ -345,7 +345,7 @@ struct Relay {
}
impl Relay {
fn to_lower(&mut self) {
fn convert_to_lowercase(&mut self) {
self.hostname = self.hostname.to_lowercase();
self.location = self.location.to_lowercase();
}

View File

@ -130,7 +130,7 @@ impl<
> RequestService<T, F>
{
/// Constructs a new request service.
pub async fn new(
pub async fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,

View File

@ -7,26 +7,26 @@ use tokio::{fs, io};
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to get path")]
PathError(#[error(source)] mullvad_paths::Error),
Path(#[error(source)] mullvad_paths::Error),
#[error(display = "Failed to remove directory {}", _0)]
RemoveDirError(String, #[error(source)] io::Error),
RemoveDir(String, #[error(source)] io::Error),
#[cfg(not(target_os = "windows"))]
#[error(display = "Failed to create directory {}", _0)]
CreateDirError(String, #[error(source)] io::Error),
CreateDir(String, #[error(source)] io::Error),
#[cfg(target_os = "windows")]
#[error(display = "Failed to get file type info")]
FileTypeError(#[error(source)] io::Error),
FileType(#[error(source)] io::Error),
#[cfg(target_os = "windows")]
#[error(display = "Failed to get dir entry")]
FileEntryError(#[error(source)] io::Error),
FileEntry(#[error(source)] io::Error),
#[cfg(target_os = "windows")]
#[error(display = "Failed to read dir entries")]
ReadDirError(#[error(source)] io::Error),
ReadDir(#[error(source)] io::Error),
}
pub async fn clear_directories() -> Result<(), Error> {
@ -35,12 +35,12 @@ pub async fn clear_directories() -> Result<(), Error> {
}
async fn clear_log_directory() -> Result<(), Error> {
let log_dir = mullvad_paths::get_log_dir().map_err(Error::PathError)?;
let log_dir = mullvad_paths::get_log_dir().map_err(Error::Path)?;
clear_directory(&log_dir).await
}
async fn clear_cache_directory() -> Result<(), Error> {
let cache_dir = mullvad_paths::cache_dir().map_err(Error::PathError)?;
let cache_dir = mullvad_paths::cache_dir().map_err(Error::Path)?;
clear_directory(&cache_dir).await
}
@ -49,22 +49,22 @@ async fn clear_directory(path: &Path) -> Result<(), Error> {
{
fs::remove_dir_all(path)
.await
.map_err(|e| Error::RemoveDirError(path.display().to_string(), e))?;
.map_err(|e| Error::RemoveDir(path.display().to_string(), e))?;
fs::create_dir_all(path)
.await
.map_err(|e| Error::CreateDirError(path.display().to_string(), e))
.map_err(|e| Error::CreateDir(path.display().to_string(), e))
}
#[cfg(target_os = "windows")]
{
let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDirError)?;
let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDir)?;
let mut result = Ok(());
while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntryError)? {
while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntry)? {
let entry_type = match entry.file_type().await {
Ok(entry_type) => entry_type,
Err(error) => {
result = result.and(Err(Error::FileTypeError(error)));
result = result.and(Err(Error::FileType(error)));
continue;
}
};
@ -74,9 +74,8 @@ async fn clear_directory(path: &Path) -> Result<(), Error> {
} else {
fs::remove_dir_all(entry.path()).await
};
result = result.and(
removal.map_err(|e| Error::RemoveDirError(entry.path().display().to_string(), e)),
);
result = result
.and(removal.map_err(|e| Error::RemoveDir(entry.path().display().to_string(), e)));
}
result
}

View File

@ -35,10 +35,10 @@ impl CurrentApiCall {
}
pub fn is_validating(&self) -> bool {
match &self.current_call {
Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_)) => true,
_ => false,
}
matches!(
&self.current_call,
Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_))
)
}
pub fn is_running_timed_totation(&self) -> bool {
@ -51,10 +51,7 @@ impl CurrentApiCall {
pub fn is_logging_in(&self) -> bool {
use Call::*;
match &self.current_call {
Some(Login(..)) => true,
_ => false,
}
matches!(&self.current_call, Some(Login(..)))
}
}

View File

@ -5,7 +5,7 @@ use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal
use std::{convert::TryFrom, sync::Once};
const INIT_ONCE: Once = Once::new();
static INIT_ONCE: Once = Once::new();
const FAULT_SIGNALS: [Signal; 5] = [
// Access to invalid memory address

View File

@ -31,7 +31,7 @@ use crate::target_state::PersistentTargetState;
use device::{PrivateAccountAndDevice, PrivateDeviceEvent};
use futures::{
channel::{mpsc, oneshot},
future::{abortable, AbortHandle, Future},
future::{abortable, AbortHandle, Future, LocalBoxFuture},
StreamExt,
};
use mullvad_relay_selector::{
@ -385,6 +385,12 @@ pub struct DaemonCommandChannel {
receiver: mpsc::UnboundedReceiver<InternalDaemonEvent>,
}
impl Default for DaemonCommandChannel {
fn default() -> Self {
Self::new()
}
}
impl DaemonCommandChannel {
pub fn new() -> Self {
let (untracked_sender, receiver) = mpsc::unbounded();
@ -472,13 +478,13 @@ impl<E> Sender<E> for DaemonEventSender<E>
where
InternalDaemonEvent: From<E>,
{
fn send(&self, event: E) -> Result<(), ()> {
fn send(&self, event: E) -> Result<(), talpid_core::mpsc::Error> {
if let Some(sender) = self.sender.upgrade() {
sender
.unbounded_send(InternalDaemonEvent::from(event))
.map_err(|_| ())
.map_err(|_| talpid_core::mpsc::Error::ChannelClosed)
} else {
Err(())
Err(talpid_core::mpsc::Error::ChannelClosed)
}
}
}
@ -684,7 +690,7 @@ where
relay_list_listener.notify_relay_list(relay_list.clone());
};
let mut relay_list_updater = RelayListUpdater::new(
let mut relay_list_updater = RelayListUpdater::spawn(
relay_selector.clone(),
api_handle.clone(),
&cache_dir,
@ -785,11 +791,11 @@ where
/// Shuts down the daemon without shutting down the underlying event listener and the shutdown
/// callbacks
fn shutdown(
fn shutdown<'a>(
self,
) -> (
L,
Vec<Pin<Box<dyn Future<Output = ()>>>>,
Vec<LocalBoxFuture<'a, ()>>,
mullvad_api::Runtime,
TunnelStateMachineHandle,
) {
@ -845,11 +851,11 @@ where
TunnelStateTransition::Disconnected => TunnelState::Disconnected,
TunnelStateTransition::Connecting(endpoint) => TunnelState::Connecting {
endpoint,
location: self.parameters_generator.get_last_location(),
location: self.parameters_generator.get_last_location().await,
},
TunnelStateTransition::Connected(endpoint) => TunnelState::Connected {
endpoint,
location: self.parameters_generator.get_last_location(),
location: self.parameters_generator.get_last_location().await,
},
TunnelStateTransition::Disconnecting(after_disconnect) => {
TunnelState::Disconnecting(after_disconnect)
@ -1184,7 +1190,7 @@ where
}
Disconnecting(..) => Self::oneshot_send(
tx,
self.parameters_generator.get_last_location(),
self.parameters_generator.get_last_location().await,
"current location",
),
Connected { location, .. } => {
@ -1703,7 +1709,8 @@ where
Self::oneshot_send(tx, Ok(()), "use_wireguard_nt response");
if settings_changed {
self.parameters_generator
.set_tunnel_options(&self.settings.tunnel_options);
.set_tunnel_options(&self.settings.tunnel_options)
.await;
self.event_listener
.notify_settings(self.settings.to_settings());
if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
@ -1854,7 +1861,8 @@ where
Self::oneshot_send(tx, Ok(()), "set_openvpn_mssfix response");
if settings_changed {
self.parameters_generator
.set_tunnel_options(&self.settings.tunnel_options);
.set_tunnel_options(&self.settings.tunnel_options)
.await;
self.event_listener
.notify_settings(self.settings.to_settings());
if self.get_target_tunnel_type() == Some(TunnelType::OpenVpn) {
@ -1963,7 +1971,8 @@ where
Self::oneshot_send(tx, Ok(()), "set_enable_ipv6 response");
if settings_changed {
self.parameters_generator
.set_tunnel_options(&self.settings.tunnel_options);
.set_tunnel_options(&self.settings.tunnel_options)
.await;
self.event_listener
.notify_settings(self.settings.to_settings());
log::info!("Initiating tunnel restart because the enable IPv6 setting changed");
@ -1991,7 +2000,8 @@ where
Self::oneshot_send(tx, Ok(()), "set_quantum_resistant_tunnel response");
if settings_changed {
self.parameters_generator
.set_tunnel_options(&self.settings.tunnel_options);
.set_tunnel_options(&self.settings.tunnel_options)
.await;
self.event_listener
.notify_settings(self.settings.to_settings());
if self.get_target_tunnel_type() == Some(TunnelType::Wireguard) {
@ -2021,7 +2031,8 @@ where
let resolvers =
dns::addresses_from_options(&settings.tunnel_options.dns_options);
self.parameters_generator
.set_tunnel_options(&settings.tunnel_options);
.set_tunnel_options(&settings.tunnel_options)
.await;
self.event_listener.notify_settings(settings);
self.send_tunnel_command(TunnelCommand::Dns(resolvers));
}
@ -2044,7 +2055,8 @@ where
Self::oneshot_send(tx, Ok(()), "set_wireguard_mtu response");
if settings_changed {
self.parameters_generator
.set_tunnel_options(&self.settings.tunnel_options);
.set_tunnel_options(&self.settings.tunnel_options)
.await;
self.event_listener
.notify_settings(self.settings.to_settings());
if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
@ -2086,7 +2098,8 @@ where
);
}
self.parameters_generator
.set_tunnel_options(&self.settings.tunnel_options);
.set_tunnel_options(&self.settings.tunnel_options)
.await;
self.event_listener
.notify_settings(self.settings.to_settings());
}

View File

@ -151,7 +151,7 @@ impl ManagementService for ManagementServiceImpl {
self.send_command_to_daemon(DaemonCommand::GetVersionInfo(tx))?;
self.wait_for_result(rx)
.await?
.ok_or(Status::not_found("no version cache"))
.ok_or_else(|| Status::not_found("no version cache"))
.map(types::AppVersionInfo::from)
.map(Response::new)
}

View File

@ -53,12 +53,12 @@ pub async fn migrate_formats(settings_dir: &Path, settings: &mut serde_json::Val
.read(true)
.open(path)
.await
.map_err(Error::ReadHistoryError)?;
.map_err(Error::ReadHistory)?;
let mut bytes = vec![];
file.read_to_end(&mut bytes)
.await
.map_err(Error::ReadHistoryError)?;
.map_err(Error::ReadHistory)?;
if is_format_v3(&bytes) {
return Ok(());
@ -92,16 +92,16 @@ fn is_format_v3(bytes: &[u8]) -> bool {
}
async fn write_format_v3(mut file: File, token: Option<AccountToken>) -> Result<()> {
file.set_len(0).await.map_err(Error::WriteHistoryError)?;
file.set_len(0).await.map_err(Error::WriteHistory)?;
file.seek(io::SeekFrom::Start(0))
.await
.map_err(Error::WriteHistoryError)?;
.map_err(Error::WriteHistory)?;
if let Some(token) = token {
file.write_all(token.as_bytes())
.await
.map_err(Error::WriteHistoryError)?;
.map_err(Error::WriteHistory)?;
}
file.sync_all().await.map_err(Error::WriteHistoryError)
file.sync_all().await.map_err(Error::WriteHistory)
}
fn try_format_v2(bytes: &[u8]) -> Result<Option<(AccountToken, serde_json::Value)>> {

View File

@ -57,31 +57,31 @@ const SETTINGS_FILE: &str = "settings.json";
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to read the settings")]
ReadError(#[error(source)] io::Error),
Read(#[error(source)] io::Error),
#[error(display = "Malformed settings")]
ParseError(#[error(source)] serde_json::Error),
Parse(#[error(source)] serde_json::Error),
#[error(display = "Unable to read any version of the settings")]
NoMatchingVersion,
#[error(display = "Unable to serialize settings to JSON")]
SerializeError(#[error(source)] serde_json::Error),
Serialize(#[error(source)] serde_json::Error),
#[error(display = "Unable to open settings for writing")]
OpenError(#[error(source)] io::Error),
Open(#[error(source)] io::Error),
#[error(display = "Unable to write new settings")]
WriteError(#[error(source)] io::Error),
Write(#[error(source)] io::Error),
#[error(display = "Unable to sync settings to disk")]
SyncError(#[error(source)] io::Error),
SyncSettings(#[error(source)] io::Error),
#[error(display = "Failed to read the account history")]
ReadHistoryError(#[error(source)] io::Error),
ReadHistory(#[error(source)] io::Error),
#[error(display = "Failed to write new account history")]
WriteHistoryError(#[error(source)] io::Error),
WriteHistory(#[error(source)] io::Error),
#[error(display = "Failed to parse account history")]
ParseHistoryError,
@ -129,10 +129,10 @@ pub(crate) async fn migrate_all(
return Ok(None);
}
let settings_bytes = fs::read(&path).await.map_err(Error::ReadError)?;
let settings_bytes = fs::read(&path).await.map_err(Error::Read)?;
let mut settings: serde_json::Value =
serde_json::from_reader(&settings_bytes[..]).map_err(Error::ParseError)?;
serde_json::from_reader(&settings_bytes[..]).map_err(Error::Parse)?;
if !settings.is_object() {
return Err(Error::NoMatchingVersion);
@ -155,7 +155,7 @@ pub(crate) async fn migrate_all(
return Ok(migration_data);
}
let buffer = serde_json::to_string_pretty(&settings).map_err(Error::SerializeError)?;
let buffer = serde_json::to_string_pretty(&settings).map_err(Error::Serialize)?;
let mut options = fs::OpenOptions::new();
#[cfg(unix)]
@ -168,11 +168,11 @@ pub(crate) async fn migrate_all(
.truncate(true)
.open(&path)
.await
.map_err(Error::OpenError)?;
.map_err(Error::Open)?;
file.write_all(&buffer.into_bytes())
.await
.map_err(Error::WriteError)?;
file.sync_data().await.map_err(Error::SyncError)?;
.map_err(Error::Write)?;
file.sync_data().await.map_err(Error::SyncSettings)?;
log::debug!("Migrated settings. Wrote settings to {}", path.display());

View File

@ -1,3 +1,4 @@
#![allow(clippy::identity_op)]
use super::{Error, Result};
use mullvad_types::settings::SettingsVersion;
use std::time::Duration;

View File

@ -66,7 +66,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
DnsState::Default
};
let addresses = if let Some(addrs) = options.get("addresses") {
serde_json::from_value(addrs.clone()).map_err(Error::ParseError)?
serde_json::from_value(addrs.clone()).map_err(Error::Parse)?
} else {
vec![]
};

View File

@ -43,8 +43,7 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
if let Some(constraints) = wireguard_constraints {
let (port, protocol): (Constraint<u16>, TransportProtocol) =
if let Some(port) = constraints.get("port") {
let port_constraint =
serde_json::from_value(port.clone()).map_err(Error::ParseError)?;
let port_constraint = serde_json::from_value(port.clone()).map_err(Error::Parse)?;
match port_constraint {
Constraint::Any => (Constraint::Any, TransportProtocol::Udp),
Constraint::Only(port) => (Constraint::Only(port), wg_protocol_from_port(port)),
@ -77,13 +76,13 @@ pub fn migrate(settings: &mut serde_json::Value) -> Result<()> {
if let Some(constraints) = openvpn_constraints {
let port: Constraint<u16> = if let Some(port) = constraints.get("port") {
serde_json::from_value(port.clone()).map_err(Error::ParseError)?
serde_json::from_value(port.clone()).map_err(Error::Parse)?
} else {
Constraint::Any
};
let transport_constraint: Constraint<TransportProtocol> =
if let Some(protocol) = constraints.get("protocol") {
serde_json::from_value(protocol.clone()).map_err(Error::ParseError)?
serde_json::from_value(protocol.clone()).map_err(Error::Parse)?
} else {
Constraint::Any
};

View File

@ -95,7 +95,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M
//
if let Some(port) = wireguard_constraints.get("port") {
let port_constraint: Constraint<TransportPort> =
serde_json::from_value(port.clone()).map_err(Error::ParseError)?;
serde_json::from_value(port.clone()).map_err(Error::Parse)?;
if let Some(transport_port) = port_constraint.option() {
let (port, obfuscation_settings) = match transport_port.protocol {
TransportProtocol::Udp => (serde_json::json!(transport_port.port), None),
@ -116,8 +116,7 @@ pub(crate) async fn migrate(settings: &mut serde_json::Value) -> Result<Option<M
let migration_data = if let Some(token) = settings.get("account_token").filter(|t| !t.is_null())
{
let token: AccountToken =
serde_json::from_value(token.clone()).map_err(Error::ParseError)?;
let token: AccountToken = serde_json::from_value(token.clone()).map_err(Error::Parse)?;
let migration_data =
if let Some(wg_data) = settings.get("wireguard").filter(|wg| !wg.is_null()) {
Some(MigrationData {

View File

@ -1,8 +1,6 @@
use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
};
use std::{future::Future, pin::Pin, sync::Arc};
use tokio::sync::Mutex;
use mullvad_relay_selector::{RelaySelector, SelectedBridge, SelectedObfuscator, SelectedRelay};
use mullvad_types::{
@ -32,7 +30,7 @@ pub enum Error {
NoBridgeAvailable,
#[error(display = "Failed to resolve hostname for custom relay")]
ResolveCustomHostnameError,
ResolveCustomHostname,
}
#[derive(Clone)]
@ -65,13 +63,13 @@ impl ParametersGenerator {
}
/// Sets the tunnel options to use when generating new tunnel parameters.
pub fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) {
self.0.lock().unwrap().tunnel_options = tunnel_options.clone();
pub async fn set_tunnel_options(&self, tunnel_options: &TunnelOptions) {
self.0.lock().await.tunnel_options = tunnel_options.clone();
}
/// Gets the location associated with the last generated tunnel parameters.
pub fn get_last_location(&self) -> Option<GeoIpLocation> {
let inner = self.0.lock().unwrap();
pub async fn get_last_location(&self) -> Option<GeoIpLocation> {
let inner = self.0.lock().await;
let relays = inner.last_generated_relays.as_ref()?;
@ -131,7 +129,7 @@ impl InnerParametersGenerator {
.to_tunnel_parameters(self.tunnel_options.clone(), None)
.map_err(|e| {
log::error!("Failed to resolve hostname for custom tunnel config: {}", e);
Error::ResolveCustomHostnameError
Error::ResolveCustomHostname
})
}
Ok((SelectedRelay::Normal(constraints), bridge, obfuscator)) => {
@ -246,13 +244,13 @@ impl TunnelParametersGenerator for ParametersGenerator {
) -> Pin<Box<dyn Future<Output = Result<TunnelParameters, ParameterGenerationError>>>> {
let generator = self.0.clone();
Box::pin(async move {
let mut inner = generator.lock().unwrap();
let mut inner = generator.lock().await;
inner
.generate(retry_attempt)
.await
.map_err(|error| match error {
Error::NoBridgeAvailable => ParameterGenerationError::NoMatchingBridgeRelay,
Error::ResolveCustomHostnameError => {
Error::ResolveCustomHostname => {
ParameterGenerationError::CustomTunnelHostResultionError
}
error => {

View File

@ -14,6 +14,7 @@ use std::{
future::Future,
io,
path::{Path, PathBuf},
str::FromStr,
time::Duration,
};
use talpid_core::mpsc::Sender;
@ -297,10 +298,10 @@ impl VersionUpdater {
if !*IS_DEV_BUILD {
let stable_version = latest_stable
.as_ref()
.and_then(|stable| ParsedAppVersion::from_str(stable));
.and_then(|stable| ParsedAppVersion::from_str(stable).ok());
let beta_version = if show_beta {
ParsedAppVersion::from_str(latest_beta)
ParsedAppVersion::from_str(latest_beta).ok()
} else {
None
};
@ -339,10 +340,10 @@ impl VersionUpdater {
let mut check_delay = next_delay();
let mut version_check = futures::future::Fuse::terminated();
// If this is a dev build ,there's no need to pester the API for version checks.
// If this is a dev build, there's no need to pester the API for version checks.
if *IS_DEV_BUILD {
log::warn!("Not checking for updates because this is a development build");
while let Some(_) = rx.next().await {}
while rx.next().await.is_some() {}
return;
}

View File

@ -384,7 +384,7 @@ impl RelaySelector {
let mut relay_matcher = RelayMatcher {
location: location.clone(),
providers: providers.clone(),
ownership: ownership.clone(),
ownership: *ownership,
tunnel: openvpn_constraints,
};
@ -492,7 +492,7 @@ impl RelaySelector {
let mut entry_relay_matcher = RelayMatcher {
location: location.clone(),
providers: providers.clone(),
ownership: ownership.clone(),
ownership: *ownership,
tunnel: wireguard_constraints.clone().into(),
};
@ -532,7 +532,7 @@ impl RelaySelector {
.clone(),
..matcher.clone()
}
.to_wireguard_matcher();
.into_wireguard_matcher();
// Pick the entry relay first if its location constraint is a subset of the exit location.
if relay_constraints.wireguard_constraints.use_multihop {
@ -746,7 +746,7 @@ impl RelaySelector {
let bridge_constraints = InternalBridgeConstraints {
location: settings.location.clone(),
providers: settings.providers.clone(),
ownership: settings.ownership.clone(),
ownership: settings.ownership,
// FIXME: This is temporary while talpid-core only supports TCP proxies
transport_protocol: Constraint::Only(TransportProtocol::Tcp),
};
@ -791,7 +791,7 @@ impl RelaySelector {
BridgeSettings::Normal(settings) => InternalBridgeConstraints {
location: settings.location.clone(),
providers: settings.providers.clone(),
ownership: settings.ownership.clone(),
ownership: settings.ownership,
transport_protocol: Constraint::Only(TransportProtocol::Tcp),
},
BridgeSettings::Custom(_bridge_settings) => InternalBridgeConstraints {
@ -1064,7 +1064,7 @@ impl RelaySelector {
let addr_in = endpoint
.as_ref()
.map(|endpoint| endpoint.to_endpoint().address.ip())
.unwrap_or(IpAddr::from(selected_relay.ipv4_addr_in));
.unwrap_or_else(|| IpAddr::from(selected_relay.ipv4_addr_in));
log::info!("Selected relay {} at {}", selected_relay.hostname, addr_in);
endpoint.map(|endpoint| NormalSelectedRelay::new(endpoint, selected_relay.clone()))
})

View File

@ -34,7 +34,7 @@ impl From<RelayConstraints> for RelayMatcher<AnyTunnelMatcher> {
}
impl RelayMatcher<AnyTunnelMatcher> {
pub fn to_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> {
pub fn into_wireguard_matcher(self) -> RelayMatcher<WireguardMatcher> {
RelayMatcher {
tunnel: self.tunnel.wireguard,
location: self.location,

View File

@ -57,7 +57,7 @@ pub struct RelayListUpdater {
}
impl RelayListUpdater {
pub fn new(
pub fn spawn(
selector: super::RelaySelector,
api_handle: MullvadRestHandle,
cache_dir: &Path,

View File

@ -2,7 +2,7 @@ use clap::{crate_authors, crate_description, crate_name, App};
use mullvad_api::{self, proxy::ApiConnectionMode};
use mullvad_management_interface::new_rpc_client;
use mullvad_types::version::ParsedAppVersion;
use std::{path::PathBuf, process, time::Duration};
use std::{path::PathBuf, process, str::FromStr, time::Duration};
use talpid_core::{
firewall::{self, Firewall},
future_retry::{constant_interval, retry_future_n},
@ -133,7 +133,7 @@ async fn main() {
async fn is_older_version(old_version: &str) -> Result<ExitStatus, Error> {
let parsed_version =
ParsedAppVersion::from_str(old_version).ok_or(Error::ParseVersionStringError)?;
ParsedAppVersion::from_str(old_version).map_err(|_| Error::ParseVersionStringError)?;
Ok(if parsed_version < *APP_VERSION {
ExitStatus::Ok
@ -152,7 +152,7 @@ async fn prepare_restart() -> Result<(), Error> {
async fn reset_firewall() -> Result<(), Error> {
// Ensure that the daemon isn't running
if let Ok(_) = new_rpc_client().await {
if new_rpc_client().await.is_ok() {
return Err(Error::DaemonIsRunning);
}

View File

@ -625,19 +625,15 @@ impl RelaySettingsUpdate {
RelaySettingsUpdate::CustomTunnelEndpoint(endpoint) => {
endpoint.endpoint().protocol == TransportProtocol::Tcp
}
RelaySettingsUpdate::Normal(update) => {
if let Some(constraints) = &update.openvpn_constraints {
!matches!(
&constraints.port,
Constraint::Only(TransportPort {
protocol: TransportProtocol::Udp,
..
})
)
} else {
true
}
}
RelaySettingsUpdate::Normal(update) => !matches!(
&update.openvpn_constraints,
Some(OpenVpnConstraints {
port: Constraint::Only(TransportPort {
protocol: TransportProtocol::Udp,
..
})
})
),
}
}
}

View File

@ -2,7 +2,10 @@
use jnix::IntoJava;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::cmp::{Ord, Ordering, PartialOrd};
use std::{
cmp::{Ord, Ordering, PartialOrd},
str::FromStr,
};
lazy_static::lazy_static! {
static ref STABLE_REGEX: Regex = Regex::new(r"^(\d{4})\.(\d+)$").unwrap();
@ -44,30 +47,33 @@ pub enum ParsedAppVersion {
Dev(u32, u32, Option<u32>, String),
}
impl ParsedAppVersion {
pub fn from_str(version: &str) -> Option<Self> {
impl FromStr for ParsedAppVersion {
type Err = ();
fn from_str(version: &str) -> Result<Self, Self::Err> {
let get_int = |cap: &regex::Captures<'_>, idx| cap.get(idx)?.as_str().parse().ok();
if let Some(caps) = STABLE_REGEX.captures(version) {
let year = get_int(&caps, 1)?;
let version = get_int(&caps, 2)?;
Some(Self::Stable(year, version))
let year = get_int(&caps, 1).ok_or(())?;
let version = get_int(&caps, 2).ok_or(())?;
Ok(Self::Stable(year, version))
} else if let Some(caps) = BETA_REGEX.captures(version) {
let year = get_int(&caps, 1)?;
let version = get_int(&caps, 2)?;
let beta_version = get_int(&caps, 3)?;
Some(Self::Beta(year, version, beta_version))
let year = get_int(&caps, 1).ok_or(())?;
let version = get_int(&caps, 2).ok_or(())?;
let beta_version = get_int(&caps, 3).ok_or(())?;
Ok(Self::Beta(year, version, beta_version))
} else if let Some(caps) = DEV_REGEX.captures(version) {
let year = get_int(&caps, 1)?;
let version = get_int(&caps, 2)?;
let year = get_int(&caps, 1).ok_or(())?;
let version = get_int(&caps, 2).ok_or(())?;
let beta_version = caps.get(4).map(|_| get_int(&caps, 5).unwrap());
let dev_hash = caps.get(6)?.as_str().to_string();
Some(Self::Dev(year, version, beta_version, dev_hash))
let dev_hash = caps.get(6).ok_or(())?.as_str().to_string();
Ok(Self::Dev(year, version, beta_version, dev_hash))
} else {
None
Err(())
}
}
}
impl ParsedAppVersion {
pub fn is_dev(&self) -> bool {
matches!(self, ParsedAppVersion::Dev(..))
}
@ -191,7 +197,7 @@ mod test {
];
for (input, expected_output) in tests {
assert_eq!(ParsedAppVersion::from_str(input), expected_output,);
assert_eq!(ParsedAppVersion::from_str(input).ok(), expected_output,);
}
}
}

View File

@ -22,16 +22,16 @@ pub enum Error {
RunResolvconf(#[error(source)] io::Error),
#[error(display = "Using 'resolvconf' to add a record failed: {}", stderr)]
AddRecordError { stderr: String },
AddRecord { stderr: String },
#[error(display = "Using 'resolvconf' to delete a record failed")]
DeleteRecordError,
DeleteRecord,
#[error(display = "Detected dnsmasq is runing and misconfigured")]
DnsmasqMisconfigurationError,
DnsmasqMisconfiguration,
#[error(display = "Current /etc/resolv.conf is not generated by resolvconf")]
ResolvconfNotInUseError,
ResolvconfNotInUse,
}
pub struct Resolvconf {
@ -50,15 +50,15 @@ impl Resolvconf {
// Check if resolvconf is managing DNS by /etc/resolv.conf
if !is_dnsmasq_running
&& !(Self::check_if_resolvconf_is_symlinked_correctly()
|| Self::check_if_resolvconf_was_generated())
&& !Self::check_if_resolvconf_is_symlinked_correctly()
&& !Self::check_if_resolvconf_was_generated()
{
return Err(Error::ResolvconfNotInUseError);
return Err(Error::ResolvconfNotInUse);
}
// Check if resolvconf can manage DNS via dnsmasq
if is_dnsmasq_running && Self::is_dnsmasq_configured_wrong() {
return Err(Error::DnsmasqMisconfigurationError);
return Err(Error::DnsmasqMisconfiguration);
}
Ok(Resolvconf {
@ -94,7 +94,7 @@ impl Resolvconf {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
return Err(Error::AddRecordError { stderr });
return Err(Error::AddRecord { stderr });
}
self.record_names.insert(record_name);
@ -118,7 +118,7 @@ impl Resolvconf {
record_name,
String::from_utf8_lossy(&output.stderr)
);
result = Err(Error::DeleteRecordError);
result = Err(Error::DeleteRecord);
}
}

View File

@ -28,7 +28,7 @@ pub enum Error {
ReadResolvConf(&'static str, #[error(source)] io::Error),
#[error(display = "resolv.conf at {} could not be parsed", _0)]
ParseError(&'static str, #[error(source)] resolv_conf::ParseError),
Parse(&'static str, #[error(source)] resolv_conf::ParseError),
#[error(display = "Failed to remove stale resolv.conf backup at {}", _0)]
RemoveBackup(&'static str, #[error(source)] io::Error),
@ -179,7 +179,7 @@ fn read_config() -> Result<Config> {
let contents = fs::read_to_string(RESOLV_CONF_PATH)
.map_err(|e| Error::ReadResolvConf(RESOLV_CONF_PATH, e))?;
let config = Config::parse(&contents).map_err(|e| Error::ParseError(RESOLV_CONF_PATH, e))?;
let config = Config::parse(&contents).map_err(|e| Error::Parse(RESOLV_CONF_PATH, e))?;
Ok(config)
}
@ -198,8 +198,8 @@ fn restore_from_backup() -> Result<()> {
match fs::read_to_string(RESOLV_CONF_BACKUP_PATH) {
Ok(backup) => {
log::info!("Restoring DNS state from backup");
let config = Config::parse(&backup)
.map_err(|e| Error::ParseError(RESOLV_CONF_BACKUP_PATH, e))?;
let config =
Config::parse(&backup).map_err(|e| Error::Parse(RESOLV_CONF_BACKUP_PATH, e))?;
write_config(&config)?;

View File

@ -1,11 +1,20 @@
/// Error type for `Sender` trait.
#[derive(err_derive::Error, Debug)]
pub enum Error {
/// The underlying channel is closed.
#[error(display = "Channel is closed")]
ChannelClosed,
}
/// Abstraction over any type that can be used similarly to an `std::mpsc::Sender`.
pub trait Sender<T> {
/// Sends an item over the underlying channel, failing only if the channel is closed.
fn send(&self, item: T) -> Result<(), ()>;
fn send(&self, item: T) -> Result<(), Error>;
}
impl<E> Sender<E> for futures::channel::mpsc::UnboundedSender<E> {
fn send(&self, content: E) -> Result<(), ()> {
self.unbounded_send(content).map_err(|_| ())
fn send(&self, content: E) -> Result<(), Error> {
self.unbounded_send(content)
.map_err(|_| Error::ChannelClosed)
}
}

View File

@ -183,7 +183,7 @@ fn construct_icmpv4_packet_inner(
let checksum = internet_checksum::checksum(buffer);
(&mut buffer[ICMP_CHECKSUM_OFFSET..])
.write(&checksum)
.write_all(&checksum)
.unwrap();
true

View File

@ -87,13 +87,13 @@ pub type Result<T> = std::result::Result<T, Error>;
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to open a netlink connection")]
ConnectError(#[error(source)] io::Error),
Connect(#[error(source)] io::Error),
#[error(display = "Failed to bind netlink socket")]
BindError(#[error(source)] io::Error),
Bind(#[error(source)] io::Error),
#[error(display = "Netlink error")]
NetlinkError(#[error(source)] rtnetlink::Error),
Netlink(#[error(source)] rtnetlink::Error),
#[error(display = "Route without a valid node")]
InvalidRoute,
@ -108,16 +108,16 @@ pub enum Error {
UnknownDeviceIndex(u32),
#[error(display = "Failed to get a route for the given IP address")]
GetRouteError(#[error(source)] rtnetlink::Error),
GetRoute(#[error(source)] rtnetlink::Error),
#[error(display = "No netlink response for route query")]
NoRouteError,
NoRoute,
#[error(display = "Route node was malformed")]
InvalidRouteNode,
#[error(display = "No link found")]
LinkNotFoundError,
LinkNotFound,
/// Unable to create routing table for tagged connections and packets.
#[error(display = "Cannot find a free routing table ID")]
@ -140,14 +140,11 @@ pub struct RouteManagerImpl {
impl RouteManagerImpl {
pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
let (mut connection, handle, messages) =
rtnetlink::new_connection().map_err(Error::ConnectError)?;
rtnetlink::new_connection().map_err(Error::Connect)?;
let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_LINK | RTMGRP_NOTIFY;
let addr = SocketAddr::new(0, mgroup_flags);
connection
.socket_mut()
.bind(&addr)
.map_err(Error::BindError)?;
connection.socket_mut().bind(&addr).map_err(Error::Bind)?;
tokio::spawn(connection);
@ -179,11 +176,11 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::NewRule((*rule).clone()));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(error) = message.payload {
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
}
}
@ -236,7 +233,7 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::GetRule(RuleMessage::default()));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
let mut rules = vec![];
@ -246,7 +243,7 @@ impl RouteManagerImpl {
rules.push(rule);
}
NetlinkPayload::Error(error) => {
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
_ => (),
}
@ -260,12 +257,12 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::DelRule(rule));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK;
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(error) = message.payload {
if error.to_io().kind() != io::ErrorKind::NotFound {
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(error)));
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(error)));
}
}
}
@ -296,7 +293,7 @@ impl RouteManagerImpl {
) -> Result<BTreeMap<u32, NetworkInterface>> {
let mut link_map = BTreeMap::new();
let mut link_request = handle.link().get().execute();
while let Some(link) = link_request.try_next().await.map_err(Error::NetlinkError)? {
while let Some(link) = link_request.try_next().await.map_err(Error::Netlink)? {
if let Some((idx, device)) = Self::map_interface(link) {
link_map.insert(idx, device);
}
@ -543,7 +540,7 @@ impl RouteManagerImpl {
async fn delete_route_if_exists(&self, route: &Route) -> Result<()> {
if let Err(error) = self.delete_route(route).await {
if let Error::NetlinkError(rtnetlink::Error::NetlinkError(msg)) = &error {
if let Error::Netlink(rtnetlink::Error::NetlinkError(msg)) = &error {
if msg.code == -libc::ESRCH {
return Ok(());
}
@ -619,7 +616,7 @@ impl RouteManagerImpl {
.del(route_message)
.execute()
.await
.map_err(Error::NetlinkError)
.map_err(Error::Netlink)
}
async fn add_route_direct(&mut self, route: Route) -> Result<()> {
@ -693,11 +690,11 @@ impl RouteManagerImpl {
let mut req = NetlinkMessage::from(RtnlMessage::NewRoute(add_message));
req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
let mut response = self.handle.request(req).map_err(Error::NetlinkError)?;
let mut response = self.handle.request(req).map_err(Error::Netlink)?;
while let Some(message) = response.next().await {
if let NetlinkPayload::Error(err) = message.payload {
return Err(Error::NetlinkError(rtnetlink::Error::NetlinkError(err)));
return Err(Error::Netlink(rtnetlink::Error::NetlinkError(err)));
}
}
Ok(())
@ -759,7 +756,7 @@ impl RouteManagerImpl {
}
None => {
log::error!("No route detected when assigning the mtu to the Wireguard tunnel");
return Err(Error::NoRouteError);
return Err(Error::NoRoute);
}
}
}
@ -767,17 +764,13 @@ impl RouteManagerImpl {
"Retried {} times looking for the correct device and could not find it",
RECURSION_LIMIT
);
Err(Error::NoRouteError)
Err(Error::NoRoute)
}
async fn get_device_mtu(&self, device: String) -> Result<u16> {
let mut links = self.handle.link().get().execute();
let target_device = LinkNla::IfName(device);
while let Some(msg) = links
.try_next()
.await
.map_err(|_| Error::LinkNotFoundError)?
{
while let Some(msg) = links.try_next().await.map_err(|_| Error::LinkNotFound)? {
let found = msg.nlas.iter().any(|e| *e == target_device);
if found {
if let Some(LinkNla::Mtu(mtu)) =
@ -788,7 +781,7 @@ impl RouteManagerImpl {
}
}
}
Err(Error::LinkNotFoundError)
Err(Error::LinkNotFound)
}
async fn get_destination_route(
@ -813,11 +806,11 @@ impl RouteManagerImpl {
let mut stream = execute_route_get_request(self.handle.clone(), message.clone());
match stream.try_next().await {
Ok(Some(route_msg)) => self.parse_route_message(route_msg),
Ok(None) => Err(Error::NoRouteError),
Ok(None) => Err(Error::NoRoute),
Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
Ok(None)
}
Err(err) => Err(Error::GetRouteError(err)),
Err(err) => Err(Error::GetRoute(err)),
}
}
}

View File

@ -19,16 +19,19 @@ use futures::stream::Stream;
#[cfg(target_os = "linux")]
use std::net::IpAddr;
#[allow(clippy::module_inception)]
#[cfg(target_os = "macos")]
#[path = "macos.rs"]
mod imp;
#[cfg(target_os = "macos")]
pub(crate) use imp::listen_for_default_route_changes;
#[allow(clippy::module_inception)]
#[cfg(target_os = "linux")]
#[path = "linux.rs"]
mod imp;
#[allow(clippy::module_inception)]
#[cfg(target_os = "android")]
#[path = "android.rs"]
mod imp;

View File

@ -1,6 +1,6 @@
use self::tun_provider::TunProvider;
use crate::{logging, routing::RouteManagerHandle};
use futures::channel::oneshot;
use futures::{channel::oneshot, future::BoxFuture};
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::{Path, PathBuf},
@ -98,6 +98,20 @@ pub struct TunnelMonitor {
monitor: InternalTunnelMonitor,
}
/// Arguments for creating a tunnel.
pub struct TunnelArgs<'a, L>
where
// L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
{
/// Resource directory.
pub resource_dir: &'a Path,
/// Callback function called when an event happens.
pub on_event: L,
/// Receiver oneshot channel for closing the tunnel.
pub tunnel_close_rx: oneshot::Receiver<()>,
}
// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
impl TunnelMonitor {
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
@ -107,12 +121,10 @@ impl TunnelMonitor {
runtime: tokio::runtime::Handle,
tunnel_parameters: &mut TunnelParameters,
log_dir: &Option<PathBuf>,
resource_dir: &Path,
on_event: L,
tun_provider: Arc<Mutex<TunProvider>>,
route_manager: RouteManagerHandle,
retry_attempt: u32,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
init_args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@ -129,9 +141,9 @@ impl TunnelMonitor {
TunnelParameters::OpenVpn(config) => runtime.block_on(Self::start_openvpn_tunnel(
config,
log_file,
resource_dir,
on_event,
tunnel_close_rx,
init_args.resource_dir,
init_args.on_event,
init_args.tunnel_close_rx,
#[cfg(target_os = "linux")]
route_manager,
)),
@ -142,12 +154,10 @@ impl TunnelMonitor {
runtime,
config,
log_file,
resource_dir,
on_event,
tun_provider,
route_manager,
retry_attempt,
tunnel_close_rx,
route_manager,
init_args,
),
}
}
@ -178,12 +188,10 @@ impl TunnelMonitor {
runtime: tokio::runtime::Handle,
params: &mut wireguard_types::TunnelParameters,
log: Option<PathBuf>,
resource_dir: &Path,
on_event: L,
tun_provider: Arc<Mutex<TunProvider>>,
route_manager: RouteManagerHandle,
retry_attempt: u32,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
init_args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
@ -211,12 +219,10 @@ impl TunnelMonitor {
None
},
log.as_deref(),
resource_dir,
on_event,
tun_provider,
route_manager,
retry_attempt,
tunnel_close_rx,
route_manager,
init_args,
)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),

View File

@ -310,10 +310,19 @@ impl OpenVpnMonitor<OpenVpnCommand> {
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let openvpn_init_args = OpenVpnTunnelInitArgs {
event_server_abort_tx: event_server_abort_tx.clone(),
event_server_abort_rx,
plugin_path,
log_path,
user_pass_file,
proxy_auth_file,
proxy_monitor,
tunnel_close_rx,
};
Self::new_internal(
cmd,
event_server_abort_tx.clone(),
event_server_abort_rx,
openvpn_init_args,
event_server::OpenvpnEventProxyImpl {
on_event,
user_pass_file_path: user_pass_file_path.clone(),
@ -324,12 +333,6 @@ impl OpenVpnMonitor<OpenVpnCommand> {
#[cfg(target_os = "linux")]
ipv6_enabled,
},
plugin_path,
log_path,
user_pass_file,
proxy_auth_file,
proxy_monitor,
tunnel_close_rx,
#[cfg(windows)]
Box::new(wintun),
)
@ -371,23 +374,36 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute
Ok(routes)
}
struct OpenVpnTunnelInitArgs {
event_server_abort_tx: triggered::Trigger,
event_server_abort_rx: triggered::Listener,
plugin_path: PathBuf,
log_path: Option<PathBuf>,
user_pass_file: mktemp::TempFile,
proxy_auth_file: Option<mktemp::TempFile>,
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
tunnel_close_rx: oneshot::Receiver<()>,
}
impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
async fn new_internal<L>(
mut cmd: C,
event_server_abort_tx: triggered::Trigger,
event_server_abort_rx: triggered::Listener,
init_args: OpenVpnTunnelInitArgs,
on_event: L,
plugin_path: PathBuf,
log_path: Option<PathBuf>,
user_pass_file: mktemp::TempFile,
proxy_auth_file: Option<mktemp::TempFile>,
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
tunnel_close_rx: oneshot::Receiver<()>,
#[cfg(windows)] wintun: Box<dyn WintunContext>,
) -> Result<OpenVpnMonitor<C>>
where
L: event_server::OpenvpnEventProxy + Send + Sync + 'static,
{
let event_server_abort_tx = init_args.event_server_abort_tx;
let event_server_abort_rx = init_args.event_server_abort_rx;
let plugin_path = init_args.plugin_path;
let log_path = init_args.log_path;
let user_pass_file = init_args.user_pass_file;
let proxy_auth_file = init_args.proxy_auth_file;
let proxy_monitor = init_args.proxy_monitor;
let tunnel_close_rx = init_args.tunnel_close_rx;
let (server_join_handle, ipc_path) = event_server::start(on_event, event_server_abort_rx)
.await
.map_err(Error::EventDispatcherError)?;
@ -1220,23 +1236,37 @@ mod tests {
.map_err(Error::RuntimeError)
}
fn create_init_args_plugin_log(
plugin_path: PathBuf,
log_path: Option<PathBuf>,
) -> OpenVpnTunnelInitArgs {
let (_close_tx, close_rx) = oneshot::channel();
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
OpenVpnTunnelInitArgs {
event_server_abort_tx,
event_server_abort_rx,
plugin_path,
log_path,
user_pass_file: TempFile::new(),
proxy_auth_file: None,
proxy_monitor: None,
tunnel_close_rx: close_rx,
}
}
fn create_init_args() -> OpenVpnTunnelInitArgs {
create_init_args_plugin_log("".into(), None)
}
#[test]
fn sets_plugin() {
let builder = TestOpenVpnBuilder::default();
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
let openvpn_init_args = create_init_args_plugin_log("./my_test_plugin".into(), None);
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
builder.clone(),
event_server_abort_tx,
event_server_abort_rx,
openvpn_init_args,
TestOpenvpnEventProxy {},
"./my_test_plugin".into(),
None,
TempFile::new(),
None,
None,
close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
));
@ -1249,20 +1279,13 @@ mod tests {
#[test]
fn sets_log() {
let builder = TestOpenVpnBuilder::default();
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
let openvpn_init_args =
create_init_args_plugin_log("".into(), Some(PathBuf::from("./my_test_log_file")));
let _ = runtime.block_on(OpenVpnMonitor::new_internal(
builder.clone(),
event_server_abort_tx,
event_server_abort_rx,
openvpn_init_args,
TestOpenvpnEventProxy {},
"".into(),
Some(PathBuf::from("./my_test_log_file")),
TempFile::new(),
None,
None,
close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
));
@ -1276,21 +1299,13 @@ mod tests {
fn exit_successfully() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(0));
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
event_server_abort_rx,
openvpn_init_args,
TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
None,
None,
close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@ -1302,21 +1317,13 @@ mod tests {
fn exit_error() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
event_server_abort_rx,
openvpn_init_args,
TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
None,
None,
close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@ -1328,21 +1335,13 @@ mod tests {
fn wait_closed() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
let openvpn_init_args = create_init_args();
let testee = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
event_server_abort_rx,
openvpn_init_args,
TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
None,
None,
close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))
@ -1354,21 +1353,13 @@ mod tests {
#[test]
fn failed_process_start() {
let builder = TestOpenVpnBuilder::default();
let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let (_close_tx, close_rx) = oneshot::channel();
let runtime = new_runtime().unwrap();
let openvpn_init_args = create_init_args();
let result = runtime
.block_on(OpenVpnMonitor::new_internal(
builder,
event_server_abort_tx,
event_server_abort_rx,
openvpn_init_args,
TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
None,
None,
close_rx,
#[cfg(windows)]
Box::new(TestWintunContext {}),
))

View File

@ -22,6 +22,12 @@ pub enum Error {
/// Factory of tunnel devices on Unix systems.
pub struct UnixTunProvider;
impl Default for UnixTunProvider {
fn default() -> Self {
Self::new()
}
}
impl UnixTunProvider {
pub fn new() -> Self {
UnixTunProvider

View File

@ -112,7 +112,7 @@ pub unsafe extern "system" fn wg_go_logging_callback(
let level = match level {
WG_GO_LOG_VERBOSE => LogLevel::Verbose,
WG_GO_LOG_ERROR | _ => LogLevel::Error,
_ => LogLevel::Error,
};
log_inner(logfile, level, "wireguard-go", &managed_msg);
}
@ -121,5 +121,5 @@ pub unsafe extern "system" fn wg_go_logging_callback(
pub type WgLogLevel = u32;
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
// const WG_GO_LOG_SILENT: WgLogLevel = 0;
const WG_GO_LOG_ERROR: WgLogLevel = 1;
// const WG_GO_LOG_ERROR: WgLogLevel = 1;
const WG_GO_LOG_VERBOSE: WgLogLevel = 2;

View File

@ -1,15 +1,11 @@
use self::config::Config;
#[cfg(not(windows))]
use super::tun_provider;
use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata};
use super::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
use crate::routing::{self, RequiredRoute, RouteManagerHandle};
use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
#[cfg(windows)]
use futures::{channel::mpsc, StreamExt};
use futures::{
channel::oneshot,
future::{abortable, AbortHandle as FutureAbortHandle},
Future,
};
#[cfg(target_os = "linux")]
use lazy_static::lazy_static;
#[cfg(target_os = "linux")]
@ -54,6 +50,7 @@ mod wireguard_nt;
use self::wireguard_go::WgGoTunnel;
type Result<T> = std::result::Result<T, Error>;
type EventCallback = Box<dyn (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Sync + 'static>;
/// Errors that can happen in the Wireguard tunnel monitor.
#[derive(err_derive::Error, Debug)]
@ -104,12 +101,7 @@ pub struct WireguardMonitor {
/// Tunnel implementation
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
/// Callback to signal tunnel events
event_callback: Box<
dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
>,
event_callback: EventCallback,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
pinger_stop_sender: sync_mpsc::Sender<()>,
_obfuscator: Option<ObfuscatorHandle>,
@ -208,13 +200,13 @@ impl WireguardMonitor {
mut config: Config,
psk_negotiation: Option<PublicKey>,
log_path: Option<&Path>,
resource_dir: &Path,
on_event: F,
tun_provider: Arc<Mutex<TunProvider>>,
route_manager: RouteManagerHandle,
retry_attempt: u32,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
init_args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
let on_event = init_args.on_event;
let endpoint_addrs: Vec<IpAddr> =
config.peers.iter().map(|peer| peer.endpoint.ip()).collect();
let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel();
@ -228,7 +220,7 @@ impl WireguardMonitor {
runtime.clone(),
&Self::patch_allowed_ips(&config, psk_negotiation.is_some()),
log_path,
resource_dir,
init_args.resource_dir,
tun_provider,
#[cfg(target_os = "windows")]
setup_done_tx,
@ -351,7 +343,7 @@ impl WireguardMonitor {
});
tokio::spawn(async move {
if tunnel_close_rx.await.is_ok() {
if init_args.tunnel_close_rx.await.is_ok() {
monitor_handle.abort();
let _ = close_msg_sender.send(CloseMsg::Stop);
}

View File

@ -4,10 +4,10 @@ use super::wireguard_kernel::wg_message::{DeviceMessage, DeviceNla, PeerNla};
#[derive(err_derive::Error, Debug, PartialEq)]
pub enum Error {
#[error(display = "Failed to parse peer pubkey from string \"_0\"")]
PubKeyParseError(String, #[error(source)] hex::FromHexError),
PubKeyParse(String, #[error(source)] hex::FromHexError),
#[error(display = "Failed to parse integer from string \"_0\"")]
IntParseError(String, #[error(source)] std::num::ParseIntError),
IntParse(String, #[error(source)] std::num::ParseIntError),
#[error(display = "Device no longer exists")]
NoTunnelDevice,
@ -47,7 +47,7 @@ impl Stats {
"public_key" => {
let mut buffer = [0u8; 32];
hex::decode_to_slice(value, &mut buffer)
.map_err(|err| Error::PubKeyParseError(value.to_string(), err))?;
.map_err(|err| Error::PubKeyParse(value.to_string(), err))?;
peer = Some(buffer);
tx_bytes = None;
rx_bytes = None;
@ -57,7 +57,7 @@ impl Stats {
value
.trim()
.parse()
.map_err(|err| Error::IntParseError(value.to_string(), err))?,
.map_err(|err| Error::IntParse(value.to_string(), err))?,
);
}
"tx_bytes" => {
@ -65,7 +65,7 @@ impl Stats {
value
.trim()
.parse()
.map_err(|err| Error::IntParseError(value.to_string(), err))?,
.map_err(|err| Error::IntParse(value.to_string(), err))?,
);
}
@ -145,7 +145,7 @@ mod test {
assert_eq!(
Stats::parse_config_str(invalid_input),
Err(Error::IntParseError(invalid_str, int_err))
Err(Error::IntParse(invalid_str, int_err))
);
}
}

View File

@ -33,16 +33,16 @@ pub use nm_tunnel::NetworkManagerTunnel;
#[error(no_from)]
pub enum Error {
#[error(display = "Failed to decode netlink message")]
DecodeError(#[error(source)] DecodeError),
Decode(#[error(source)] DecodeError),
#[error(display = "Failed to execute netlink control request")]
NetlinkControlMessageError(#[error(source)] nl_message::Error),
NetlinkControlMessage(#[error(source)] nl_message::Error),
#[error(display = "Failed to open netlink socket")]
NetlinkSocketError(#[error(source)] std::io::Error),
NetlinkSocket(#[error(source)] std::io::Error),
#[error(display = "Failed to send netlink control request")]
NetlinkRequestError(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
NetlinkRequest(#[error(source)] netlink_proto::Error<NetlinkControlMessage>),
#[error(display = "WireGuard netlink interface unavailable. Is the kernel module loaded?")]
WireguardNetlinkInterfaceUnavailable,
@ -60,25 +60,25 @@ pub enum Error {
NoDevice,
#[error(display = "Failed to get config: _0")]
WgGetConfError(netlink_packet_core::error::ErrorMessage),
WgGetConf(netlink_packet_core::error::ErrorMessage),
#[error(display = "Failed to apply config: _0")]
WgSetConfError(netlink_packet_core::error::ErrorMessage),
WgSetConf(netlink_packet_core::error::ErrorMessage),
#[error(display = "Interface name too long")]
InterfaceNameError,
InterfaceName,
#[error(display = "Send request error")]
SendRequestError(#[error(source)] NetlinkError<DeviceMessage>),
SendRequest(#[error(source)] NetlinkError<DeviceMessage>),
#[error(display = "Create device error")]
NetlinkCreateDeviceError(#[error(source)] rtnetlink::Error),
NetlinkCreateDevice(#[error(source)] rtnetlink::Error),
#[error(display = "Add IP to device error")]
NetlinkSetIpError(rtnetlink::Error),
NetlinkSetIp(rtnetlink::Error),
#[error(display = "Failed to delete device")]
DeleteDeviceError(#[error(source)] rtnetlink::Error),
DeleteDevice(#[error(source)] rtnetlink::Error),
#[error(display = "NetworkManager error")]
NetworkManager(#[error(source)] nm_tunnel::Error),
@ -98,7 +98,7 @@ impl Handle {
pub async fn connect() -> Result<Self, Error> {
let message_type = Self::get_wireguard_message_type().await?;
let (conn, wireguard_connection, _messages) =
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
let wg_handle = WireguardConnection {
message_type,
connection: wireguard_connection,
@ -106,7 +106,7 @@ impl Handle {
let (abortable_connection, wg_abort_handle) = abortable(conn);
tokio::spawn(abortable_connection);
let (conn, route_handle, _messages) =
rtnetlink::new_connection().map_err(Error::NetlinkSocketError)?;
rtnetlink::new_connection().map_err(Error::NetlinkSocket)?;
let (abortable_connection, route_abort_handle) = abortable(conn);
tokio::spawn(abortable_connection);
@ -120,21 +120,21 @@ impl Handle {
async fn get_wireguard_message_type() -> Result<u16, Error> {
let (conn, mut handle, _messages) =
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocketError)?;
netlink_proto::new_connection(NETLINK_GENERIC).map_err(Error::NetlinkSocket)?;
let (conn, abort_handle) = abortable(conn);
tokio::spawn(conn);
let result = async move {
let mut message: NetlinkMessage<NetlinkControlMessage> =
NetlinkControlMessage::get_netlink_family_id(CString::new("wireguard").unwrap())
.map_err(Error::NetlinkControlMessageError)?
.map_err(Error::NetlinkControlMessage)?
.into();
message.header.flags = NLM_F_REQUEST | NLM_F_ACK;
let mut req = handle
.request(message, SocketAddr::new(0, 0))
.map_err(Error::NetlinkRequestError)?;
.map_err(Error::NetlinkRequest)?;
let response = req.next().await;
if let Some(response) = response {
if let NetlinkPayload::InnerMessage(msg) = response.payload {
@ -177,14 +177,14 @@ impl Handle {
let mut response = self
.route_handle
.request(add_request)
.map_err(Error::NetlinkCreateDeviceError)?;
.map_err(Error::NetlinkCreateDevice)?;
while let Some(response_message) = response.next().await {
if let NetlinkPayload::Error(err) = response_message.payload {
// if the device exists, verify that it's a wireguard device
if -err.code != libc::EEXIST {
return Err(Error::NetlinkCreateDeviceError(
rtnetlink::Error::NetlinkError(err),
));
return Err(Error::NetlinkCreateDevice(rtnetlink::Error::NetlinkError(
err,
)));
}
}
}
@ -208,9 +208,9 @@ impl Handle {
let mut response = self
.route_handle
.request(request)
.map_err(Error::NetlinkSetIpError)?;
.map_err(Error::NetlinkSetIp)?;
while let Some(response_message) = response.next().await {
consume_netlink_error(response_message, Error::NetlinkSetIpError)?;
consume_netlink_error(response_message, Error::NetlinkSetIp)?;
}
Ok(())
@ -226,9 +226,9 @@ impl Handle {
let mut response = self
.route_handle
.request(request)
.map_err(Error::DeleteDeviceError)?;
.map_err(Error::DeleteDevice)?;
while let Some(message) = response.next().await {
consume_netlink_error(message, Error::DeleteDeviceError)?;
consume_netlink_error(message, Error::DeleteDevice)?;
}
Ok(())
@ -269,7 +269,7 @@ impl WireguardConnection {
let mut response = self
.connection
.request(netlink_message, SocketAddr::new(0, 0))
.map_err(Error::SendRequestError)?;
.map_err(Error::SendRequest)?;
match response.next().await {
Some(received_message) => match received_message.payload {
NetlinkPayload::InnerMessage(inner) => Ok(inner),
@ -277,7 +277,7 @@ impl WireguardConnection {
if err.code == -libc::ENODEV {
Err(Error::NoDevice)
} else {
Err(Error::WgGetConfError(err))
Err(Error::WgGetConf(err))
}
}
anything_else => {
@ -297,11 +297,11 @@ impl WireguardConnection {
let mut request = self
.connection
.request(netlink_message, SocketAddr::new(0, 0))
.map_err(Error::SendRequestError)?;
.map_err(Error::SendRequest)?;
while let Some(response) = request.next().await {
if let NetlinkPayload::Error(err) = response.payload {
return Err(Error::WgSetConfError(err));
return Err(Error::WgSetConf(err));
}
}
Ok(())

View File

@ -110,9 +110,9 @@ impl DeviceMessage {
}
pub fn get_by_name(message_type: u16, name: String) -> Result<Self, Error> {
let c_name = CString::new(name).map_err(|_| Error::InterfaceNameError)?;
let c_name = CString::new(name).map_err(|_| Error::InterfaceName)?;
if c_name.as_bytes_with_nul().len() > libc::IFNAMSIZ {
return Err(Error::InterfaceNameError);
return Err(Error::InterfaceName);
}
Ok(Self {
@ -178,9 +178,7 @@ impl NetlinkDeserializable<DeviceMessage> for DeviceMessage {
let new_payload = &payload[mem::size_of::<libc::genlmsghdr>()..];
let mut nlas = vec![];
for buf in NlasIterator::new(new_payload) {
nlas.push(
DeviceNla::parse(&buf.map_err(Error::DecodeError)?).map_err(Error::DecodeError)?,
);
nlas.push(DeviceNla::parse(&buf.map_err(Error::Decode)?).map_err(Error::Decode)?);
}
Ok(DeviceMessage {
@ -391,13 +389,13 @@ impl Nla for PeerNla {
InetAddr::V4(sockaddr_in) => {
// SAFETY: `sockaddr_in` has no padding bytes
buffer
.write(unsafe { struct_as_slice(sockaddr_in) })
.write_all(unsafe { struct_as_slice(sockaddr_in) })
.expect("Buffer too small for sockaddr_in");
}
InetAddr::V6(sockaddr_in6) => {
// SAFETY: `sockaddr_in` has no padding bytes
buffer
.write(unsafe { struct_as_slice(sockaddr_in6) })
.write_all(unsafe { struct_as_slice(sockaddr_in6) })
.expect("Buffer too small for sockaddr_in6");
}
},
@ -408,7 +406,7 @@ impl Nla for PeerNla {
let timespec: &libc::timespec = last_handshake.as_ref();
// SAFETY: `timespec` has no padding bytes
buffer
.write(unsafe { struct_as_slice(timespec) })
.write_all(unsafe { struct_as_slice(timespec) })
.expect("Buffer too small for timespec");
}
RxBytes(num_bytes) | TxBytes(num_bytes) => NativeEndian::write_u64(buffer, *num_bytes),
@ -535,7 +533,7 @@ impl Nla for AllowedIpNla {
}
IpAddr(ip_addr) => {
buffer
.write(&ip_addr_to_bytes(ip_addr))
.write_all(&ip_addr_to_bytes(ip_addr))
.expect("Buffer too small for AllowedIpNla::IpAddr");
}
CidrMask(cidr_mask) => buffer[0] = *cidr_mask,

View File

@ -6,7 +6,9 @@ use super::{
use crate::{
firewall::FirewallPolicy,
routing::RouteManager,
tunnel::{self, tun_provider::TunProvider, TunnelEvent, TunnelMetadata, TunnelMonitor},
tunnel::{
self, tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata, TunnelMonitor,
},
};
use cfg_if::cfg_if;
use futures::{
@ -142,16 +144,20 @@ impl ConnectingState {
}
};
let init_args = TunnelArgs {
resource_dir: &resource_dir,
on_event: on_tunnel_event,
tunnel_close_rx,
};
let block_reason = match TunnelMonitor::start(
runtime,
&mut tunnel_parameters,
&log_dir,
&resource_dir,
on_tunnel_event,
tun_provider,
route_manager_handle,
retry_attempt,
tunnel_close_rx,
route_manager_handle,
init_args,
) {
Ok(monitor) => {
let reason = Self::wait_for_tunnel_monitor(monitor, retry_attempt);

View File

@ -132,23 +132,25 @@ pub async fn spawn(
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let weak_command_tx = Arc::downgrade(&command_tx);
let state_machine = TunnelStateMachine::new(
initial_settings,
weak_command_tx,
offline_state_listener,
let init_args = TunnelStateMachineInitArgs {
settings: initial_settings,
command_tx: weak_command_tx,
offline_state_tx: offline_state_listener,
tunnel_parameters_generator,
tun_provider,
log_dir,
resource_dir,
command_rx,
commands_rx: command_rx,
#[cfg(target_os = "windows")]
volume_update_rx,
#[cfg(target_os = "macos")]
exclusion_gid,
#[cfg(target_os = "android")]
android_context,
)
.await?;
};
let state_machine = TunnelStateMachine::new(init_args).await?;
#[cfg(windows)]
let split_tunnel = state_machine.shared_values.split_tunnel.handle();
@ -219,20 +221,35 @@ struct TunnelStateMachine {
shared_values: SharedTunnelStateValues,
}
/// Tunnel state machine initialization arguments arguments
struct TunnelStateMachineInitArgs<G: TunnelParametersGenerator> {
settings: InitialTunnelState,
command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
offline_state_tx: mpsc::UnboundedSender<bool>,
tunnel_parameters_generator: G,
tun_provider: TunProvider,
log_dir: Option<PathBuf>,
resource_dir: PathBuf,
commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
#[cfg(target_os = "windows")]
volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")]
exclusion_gid: u32,
#[cfg(target_os = "android")]
android_context: AndroidContext,
}
impl TunnelStateMachine {
async fn new(
settings: InitialTunnelState,
command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
offline_state_tx: mpsc::UnboundedSender<bool>,
tunnel_parameters_generator: impl TunnelParametersGenerator,
tun_provider: TunProvider,
log_dir: Option<PathBuf>,
resource_dir: PathBuf,
commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
#[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")] exclusion_gid: u32,
#[cfg(target_os = "android")] android_context: AndroidContext,
args: TunnelStateMachineInitArgs<impl TunnelParametersGenerator>,
) -> Result<Self, Error> {
#[cfg(target_os = "windows")]
let volume_update_rx = args.volume_update_rx;
#[cfg(target_os = "macos")]
let exclusion_gid = args.exclusion_gid;
#[cfg(target_os = "android")]
let android_context = args.android_context;
let runtime = tokio::runtime::Handle::current();
#[cfg(target_os = "macos")]
@ -242,20 +259,24 @@ impl TunnelStateMachine {
let power_mgmt_rx = crate::windows::window::PowerManagementListener::new();
#[cfg(windows)]
let split_tunnel =
split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx)
.map_err(Error::InitSplitTunneling)?;
let split_tunnel = split_tunnel::SplitTunnel::new(
runtime.clone(),
args.command_tx.clone(),
volume_update_rx,
)
.map_err(Error::InitSplitTunneling)?;
let args = FirewallArguments {
initial_state: if settings.block_when_disconnected || !settings.reset_firewall {
InitialFirewallState::Blocked(settings.allowed_endpoint.clone())
let fw_args = FirewallArguments {
initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall
{
InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone())
} else {
InitialFirewallState::None
},
allow_lan: settings.allow_lan,
allow_lan: args.settings.allow_lan,
};
let firewall = Firewall::from_args(args).map_err(Error::InitFirewallError)?;
let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?;
let route_manager = RouteManager::new(HashSet::new())
.await
.map_err(Error::InitRouteManagerError)?;
@ -267,20 +288,20 @@ impl TunnelStateMachine {
.handle()
.map_err(Error::InitRouteManagerError)?,
#[cfg(target_os = "macos")]
command_tx.clone(),
args.command_tx.clone(),
)
.map_err(Error::InitDnsMonitorError)?;
let (offline_tx, mut offline_rx) = mpsc::unbounded();
let initial_offline_state_tx = offline_state_tx.clone();
let initial_offline_state_tx = args.offline_state_tx.clone();
tokio::spawn(async move {
while let Some(offline) = offline_rx.next().await {
if let Some(tx) = command_tx.upgrade() {
if let Some(tx) = args.command_tx.upgrade() {
let _ = tx.unbounded_send(TunnelCommand::IsOffline(offline));
} else {
break;
}
let _ = offline_state_tx.unbounded_send(offline);
let _ = args.offline_state_tx.unbounded_send(offline);
}
});
let mut offline_monitor = offline::spawn_monitor(
@ -301,7 +322,7 @@ impl TunnelStateMachine {
#[cfg(windows)]
split_tunnel
.set_paths_sync(&settings.exclude_paths)
.set_paths_sync(&args.settings.exclude_paths)
.map_err(Error::InitSplitTunneling)?;
let mut shared_values = SharedTunnelStateValues {
@ -312,15 +333,15 @@ impl TunnelStateMachine {
dns_monitor,
route_manager,
_offline_monitor: offline_monitor,
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
allow_lan: args.settings.allow_lan,
block_when_disconnected: args.settings.block_when_disconnected,
is_offline,
dns_servers: settings.dns_servers,
allowed_endpoint: settings.allowed_endpoint,
tunnel_parameters_generator: Box::new(tunnel_parameters_generator),
tun_provider: Arc::new(Mutex::new(tun_provider)),
log_dir,
resource_dir,
dns_servers: args.settings.dns_servers,
allowed_endpoint: args.settings.allowed_endpoint,
tunnel_parameters_generator: Box::new(args.tunnel_parameters_generator),
tun_provider: Arc::new(Mutex::new(args.tun_provider)),
log_dir: args.log_dir,
resource_dir: args.resource_dir,
#[cfg(target_os = "linux")]
connectivity_check_was_enabled: None,
#[cfg(target_os = "macos")]
@ -331,11 +352,11 @@ impl TunnelStateMachine {
tokio::task::spawn_blocking(move || {
let (initial_state, _) =
DisconnectedState::enter(&mut shared_values, settings.reset_firewall);
DisconnectedState::enter(&mut shared_values, args.settings.reset_firewall);
Ok(TunnelStateMachine {
current_state: Some(initial_state),
commands: commands_rx.fuse(),
commands: args.commands_rx.fuse(),
shared_values,
})
})

View File

@ -59,6 +59,7 @@ const MAXIMUM_SUPPORTED_MINOR_VERSION: u32 = 26;
const NM_DEVICE_STATE_CHANGED: &str = "StateChanged";
pub type Result<T> = std::result::Result<T, Error>;
type NetworkSettings<'a> = HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>;
#[derive(err_derive::Error, Debug)]
pub enum Error {
@ -447,10 +448,8 @@ impl NetworkManager {
let device = self.as_path(&device_path);
// Get the last applied connection
let (mut settings, version_id): (
HashMap<String, HashMap<String, Variant<Box<dyn RefArg>>>>,
u64,
) = device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?;
let (mut settings, version_id): (NetworkSettings, u64) =
device.method_call(NM_DEVICE, "GetAppliedConnection", (0u32,))?;
// Keep changed routes.
// These routes were modified outside NM, likely by RouteManager.
@ -576,7 +575,7 @@ impl NetworkManager {
}
fn update_dns_config<'a, T>(
settings: &mut HashMap<String, HashMap<String, Variant<Box<dyn RefArg + 'a>>>>,
settings: &mut NetworkSettings<'a>,
ip_protocol: &'static str,
servers: T,
) where

View File

@ -349,7 +349,7 @@ impl SystemdResolved {
.map_err(Error::DBusRpcError)
}
fn link_disable_dns_over_tls<'a, 'b: 'a>(&'a self, interface_index: u32) -> Result<()> {
fn link_disable_dns_over_tls(&self, interface_index: u32) -> Result<()> {
let link_object_path = self
.fetch_link(interface_index)
.map_err(|e| Error::GetLinkError(Box::new(e)))?;

View File

@ -95,6 +95,7 @@ fn default_wgnt_setting() -> bool {
true
}
#[allow(clippy::derivable_impls)]
impl Default for TunnelOptions {
fn default() -> Self {
Self {