Fix tun file descriptor ownership
We accidentally borrowed the file descriptor when we should have moved it. This commit adds more `OwnedFd` and friends to help handle ownership correctly. Signed-off-by: Joakim Hulthe <joakim.hulthe@mullvad.net>
This commit is contained in:
parent
f0efcc68cf
commit
b39d040d9f
@ -604,7 +604,7 @@ impl SharedTunnelStateValues {
|
|||||||
|
|
||||||
#[cfg(target_os = "android")]
|
#[cfg(target_os = "android")]
|
||||||
pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
|
pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
|
||||||
if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) {
|
if let Err(err) = self.tun_provider.lock().unwrap().bypass(&fd) {
|
||||||
log::error!("Failed to bypass socket {}", err);
|
log::error!("Failed to bypass socket {}", err);
|
||||||
}
|
}
|
||||||
let _ = tx.send(());
|
let _ = tx.send(());
|
||||||
|
@ -197,7 +197,7 @@ impl AndroidTunProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Allow a socket to bypass the tunnel.
|
/// Allow a socket to bypass the tunnel.
|
||||||
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
|
pub fn bypass(&mut self, socket: &impl AsRawFd) -> Result<(), Error> {
|
||||||
let env = JnixEnv::from(
|
let env = JnixEnv::from(
|
||||||
self.jvm
|
self.jvm
|
||||||
.attach_current_thread_as_daemon()
|
.attach_current_thread_as_daemon()
|
||||||
@ -212,7 +212,7 @@ impl AndroidTunProvider {
|
|||||||
self.object.as_obj(),
|
self.object.as_obj(),
|
||||||
create_tun_method,
|
create_tun_method,
|
||||||
JavaType::Primitive(Primitive::Boolean),
|
JavaType::Primitive(Primitive::Boolean),
|
||||||
&[JValue::Int(socket)],
|
&[JValue::Int(socket.as_raw_fd())],
|
||||||
)
|
)
|
||||||
.map_err(|cause| Error::CallMethod("bypass", cause))?;
|
.map_err(|cause| Error::CallMethod("bypass", cause))?;
|
||||||
|
|
||||||
@ -404,7 +404,7 @@ impl VpnServiceTun {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Allow a socket to bypass the tunnel.
|
/// Allow a socket to bypass the tunnel.
|
||||||
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
|
pub fn bypass(&mut self, socket: &impl AsFd) -> Result<(), Error> {
|
||||||
let env = JnixEnv::from(
|
let env = JnixEnv::from(
|
||||||
self.jvm
|
self.jvm
|
||||||
.attach_current_thread_as_daemon()
|
.attach_current_thread_as_daemon()
|
||||||
@ -419,7 +419,7 @@ impl VpnServiceTun {
|
|||||||
self.object.as_obj(),
|
self.object.as_obj(),
|
||||||
create_tun_method,
|
create_tun_method,
|
||||||
JavaType::Primitive(Primitive::Boolean),
|
JavaType::Primitive(Primitive::Boolean),
|
||||||
&[JValue::Int(socket)],
|
&[JValue::Int(socket.as_fd().as_raw_fd())],
|
||||||
)
|
)
|
||||||
.map_err(|cause| Error::CallMethod("bypass", cause))?;
|
.map_err(|cause| Error::CallMethod("bypass", cause))?;
|
||||||
|
|
||||||
|
@ -90,9 +90,7 @@ pub async fn open_boringtun_tunnel(
|
|||||||
let mut config = tun07::Configuration::default();
|
let mut config = tun07::Configuration::default();
|
||||||
config.raw_fd(fd);
|
config.raw_fd(fd);
|
||||||
|
|
||||||
boringtun_config.on_bind = Some(Box::new(move |socket| {
|
boringtun_config.on_bind = Some(Box::new(move |socket| tun.bypass(socket).unwrap()));
|
||||||
tun.bypass(socket.as_raw_fd()).unwrap()
|
|
||||||
}));
|
|
||||||
|
|
||||||
let device = tun07::Device::new(&config).unwrap();
|
let device = tun07::Device::new(&config).unwrap();
|
||||||
tun07::AsyncDevice::new(device).unwrap()
|
tun07::AsyncDevice::new(device).unwrap()
|
||||||
|
@ -117,7 +117,7 @@ async fn bypass_vpn(
|
|||||||
// Exclude remote obfuscation socket or bridge
|
// Exclude remote obfuscation socket or bridge
|
||||||
log::debug!("Excluding remote socket fd from the tunnel");
|
log::debug!("Excluding remote socket fd from the tunnel");
|
||||||
let _ = tokio::task::spawn_blocking(move || {
|
let _ = tokio::task::spawn_blocking(move || {
|
||||||
if let Err(error) = tun_provider.lock().unwrap().bypass(remote_socket_fd) {
|
if let Err(error) = tun_provider.lock().unwrap().bypass(&remote_socket_fd) {
|
||||||
log::error!("Failed to exclude remote socket fd: {error}");
|
log::error!("Failed to exclude remote socket fd: {error}");
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -18,8 +18,6 @@ use std::borrow::Cow;
|
|||||||
#[cfg(daita)]
|
#[cfg(daita)]
|
||||||
use std::ffi::CString;
|
use std::ffi::CString;
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use std::os::unix::io::AsRawFd;
|
|
||||||
#[cfg(unix)]
|
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::{
|
use std::{
|
||||||
future::Future,
|
future::Future,
|
||||||
@ -300,10 +298,10 @@ impl WgGoTunnelState {
|
|||||||
let socket_v6 = self.tunnel_handle.get_socket_v6();
|
let socket_v6 = self.tunnel_handle.get_socket_v6();
|
||||||
let mut provider = tun_provider.lock().unwrap();
|
let mut provider = tun_provider.lock().unwrap();
|
||||||
provider
|
provider
|
||||||
.bypass(socket_v4)
|
.bypass(&socket_v4)
|
||||||
.map_err(super::TunnelError::BypassError)?;
|
.map_err(super::TunnelError::BypassError)?;
|
||||||
provider
|
provider
|
||||||
.bypass(socket_v6)
|
.bypass(&socket_v6)
|
||||||
.map_err(super::TunnelError::BypassError)?;
|
.map_err(super::TunnelError::BypassError)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -334,7 +332,7 @@ impl WgGoTunnel {
|
|||||||
let handle = wireguard_go_rs::Tunnel::turn_on(
|
let handle = wireguard_go_rs::Tunnel::turn_on(
|
||||||
mtu,
|
mtu,
|
||||||
&wg_config_str,
|
&wg_config_str,
|
||||||
tunnel_fd.as_raw_fd(),
|
tunnel_fd,
|
||||||
Some(logging::wg_go_logging_callback),
|
Some(logging::wg_go_logging_callback),
|
||||||
logging_context.ordinal,
|
logging_context.ordinal,
|
||||||
)
|
)
|
||||||
@ -529,7 +527,7 @@ impl WgGoTunnel {
|
|||||||
|
|
||||||
let handle = wireguard_go_rs::Tunnel::turn_on(
|
let handle = wireguard_go_rs::Tunnel::turn_on(
|
||||||
&wg_config_str,
|
&wg_config_str,
|
||||||
tunnel_fd.as_raw_fd(),
|
tunnel_fd,
|
||||||
Some(logging::wg_go_logging_callback),
|
Some(logging::wg_go_logging_callback),
|
||||||
logging_context.ordinal,
|
logging_context.ordinal,
|
||||||
)
|
)
|
||||||
@ -611,7 +609,7 @@ impl WgGoTunnel {
|
|||||||
&exit_config_str,
|
&exit_config_str,
|
||||||
&entry_config_str,
|
&entry_config_str,
|
||||||
&private_ip,
|
&private_ip,
|
||||||
tunnel_fd.as_raw_fd(),
|
tunnel_fd,
|
||||||
Some(logging::wg_go_logging_callback),
|
Some(logging::wg_go_logging_callback),
|
||||||
logging_context.ordinal,
|
logging_context.ordinal,
|
||||||
)
|
)
|
||||||
@ -658,8 +656,8 @@ impl WgGoTunnel {
|
|||||||
let socket_v4 = handle.get_socket_v4();
|
let socket_v4 = handle.get_socket_v4();
|
||||||
let socket_v6 = handle.get_socket_v6();
|
let socket_v6 = handle.get_socket_v6();
|
||||||
|
|
||||||
tunnel_device.bypass(socket_v4)?;
|
tunnel_device.bypass(&socket_v4)?;
|
||||||
tunnel_device.bypass(socket_v6)?;
|
tunnel_device.bypass(&socket_v6)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -8,18 +8,22 @@
|
|||||||
|
|
||||||
use core::ffi::{c_char, CStr};
|
use core::ffi::{c_char, CStr};
|
||||||
use core::mem::ManuallyDrop;
|
use core::mem::ManuallyDrop;
|
||||||
#[cfg(target_os = "windows")]
|
|
||||||
use core::mem::MaybeUninit;
|
|
||||||
use core::slice;
|
use core::slice;
|
||||||
#[cfg(target_os = "windows")]
|
|
||||||
use std::ffi::CString;
|
|
||||||
use talpid_types::drop_guard::on_drop;
|
use talpid_types::drop_guard::on_drop;
|
||||||
#[cfg(target_os = "windows")]
|
|
||||||
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
|
|
||||||
use zeroize::Zeroize;
|
use zeroize::Zeroize;
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(target_os = "android")]
|
||||||
pub type Fd = std::os::unix::io::RawFd;
|
use std::os::fd::BorrowedFd;
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
use std::os::fd::{IntoRawFd, OwnedFd};
|
||||||
|
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
use core::mem::MaybeUninit;
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
use std::ffi::CString;
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
|
||||||
|
|
||||||
pub type WgLogLevel = u32;
|
pub type WgLogLevel = u32;
|
||||||
|
|
||||||
@ -84,17 +88,19 @@ impl Tunnel {
|
|||||||
pub fn turn_on(
|
pub fn turn_on(
|
||||||
#[cfg(not(target_os = "android"))] mtu: isize,
|
#[cfg(not(target_os = "android"))] mtu: isize,
|
||||||
settings: &CStr,
|
settings: &CStr,
|
||||||
device: Fd,
|
device: OwnedFd,
|
||||||
logging_callback: Option<LoggingCallback>,
|
logging_callback: Option<LoggingCallback>,
|
||||||
logging_context: LoggingContext,
|
logging_context: LoggingContext,
|
||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
// SAFETY: pointer is valid for the lifetime of this function
|
// SAFETY:
|
||||||
|
// - pointer is valid for the lifetime of `wgTurnOn`.
|
||||||
|
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
|
||||||
let code = unsafe {
|
let code = unsafe {
|
||||||
ffi::wgTurnOn(
|
ffi::wgTurnOn(
|
||||||
#[cfg(not(target_os = "android"))]
|
#[cfg(not(target_os = "android"))]
|
||||||
mtu,
|
mtu,
|
||||||
settings.as_ptr(),
|
settings.as_ptr(),
|
||||||
device,
|
device.into_raw_fd(), // Transfer ownership of the fd to Go
|
||||||
logging_callback,
|
logging_callback,
|
||||||
logging_context,
|
logging_context,
|
||||||
)
|
)
|
||||||
@ -181,17 +187,19 @@ impl Tunnel {
|
|||||||
exit_settings: &CStr,
|
exit_settings: &CStr,
|
||||||
entry_settings: &CStr,
|
entry_settings: &CStr,
|
||||||
private_ip: &CStr,
|
private_ip: &CStr,
|
||||||
device: Fd,
|
device: OwnedFd,
|
||||||
logging_callback: Option<LoggingCallback>,
|
logging_callback: Option<LoggingCallback>,
|
||||||
logging_context: LoggingContext,
|
logging_context: LoggingContext,
|
||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
// SAFETY: pointer is valid for the lifetime of this function
|
// SAFETY:
|
||||||
|
// - pointers are valid for the lifetime of `wgTurnOnMultihop`.
|
||||||
|
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
|
||||||
let code = unsafe {
|
let code = unsafe {
|
||||||
ffi::wgTurnOnMultihop(
|
ffi::wgTurnOnMultihop(
|
||||||
exit_settings.as_ptr(),
|
exit_settings.as_ptr(),
|
||||||
entry_settings.as_ptr(),
|
entry_settings.as_ptr(),
|
||||||
private_ip.as_ptr(),
|
private_ip.as_ptr(),
|
||||||
device,
|
device.into_raw_fd(), // Transfer ownership of the fd to Go
|
||||||
logging_callback,
|
logging_callback,
|
||||||
logging_context,
|
logging_context,
|
||||||
)
|
)
|
||||||
@ -279,16 +287,22 @@ impl Tunnel {
|
|||||||
|
|
||||||
/// Get the file descriptor of the tunnel IPv4 socket.
|
/// Get the file descriptor of the tunnel IPv4 socket.
|
||||||
#[cfg(target_os = "android")]
|
#[cfg(target_os = "android")]
|
||||||
pub fn get_socket_v4(&self) -> Fd {
|
pub fn get_socket_v4(&self) -> BorrowedFd {
|
||||||
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
|
// SAFETY:
|
||||||
unsafe { ffi::wgGetSocketV4(self.handle) }
|
// - self.handle is a valid pointer to an active wireguard-go tunnel.
|
||||||
|
// - file descriptor won't be closed until wgTurnOff is called,
|
||||||
|
// which can't happen while `self` is borrowed.
|
||||||
|
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV4(self.handle)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the file descriptor of the tunnel IPv6 socket.
|
/// Get the file descriptor of the tunnel IPv6 socket.
|
||||||
#[cfg(target_os = "android")]
|
#[cfg(target_os = "android")]
|
||||||
pub fn get_socket_v6(&self) -> Fd {
|
pub fn get_socket_v6(&self) -> BorrowedFd {
|
||||||
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
|
// SAFETY:
|
||||||
unsafe { ffi::wgGetSocketV6(self.handle) }
|
// - self.handle is a valid pointer to an active wireguard-go tunnel.
|
||||||
|
// - file descriptor won't be closed until wgTurnOff is called,
|
||||||
|
// which can't happen while `self` is borrowed.
|
||||||
|
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV6(self.handle)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,11 +343,12 @@ impl Error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mod ffi {
|
mod ffi {
|
||||||
#[cfg(not(target_os = "windows"))]
|
|
||||||
use super::Fd;
|
|
||||||
use super::{LoggingCallback, LoggingContext};
|
use super::{LoggingCallback, LoggingContext};
|
||||||
use core::ffi::{c_char, c_void};
|
use core::ffi::{c_char, c_void};
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
use std::os::fd::RawFd;
|
||||||
|
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
/// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
|
/// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
|
||||||
/// for the tunnel device and logging. For targets other than android, this also takes an
|
/// for the tunnel device and logging. For targets other than android, this also takes an
|
||||||
@ -345,7 +360,7 @@ mod ffi {
|
|||||||
pub fn wgTurnOn(
|
pub fn wgTurnOn(
|
||||||
mtu: isize,
|
mtu: isize,
|
||||||
settings: *const c_char,
|
settings: *const c_char,
|
||||||
fd: Fd,
|
fd: RawFd,
|
||||||
logging_callback: Option<LoggingCallback>,
|
logging_callback: Option<LoggingCallback>,
|
||||||
logging_context: LoggingContext,
|
logging_context: LoggingContext,
|
||||||
) -> i32;
|
) -> i32;
|
||||||
@ -353,7 +368,7 @@ mod ffi {
|
|||||||
#[cfg(target_os = "android")]
|
#[cfg(target_os = "android")]
|
||||||
pub fn wgTurnOn(
|
pub fn wgTurnOn(
|
||||||
settings: *const c_char,
|
settings: *const c_char,
|
||||||
fd: Fd,
|
fd: RawFd,
|
||||||
logging_callback: Option<LoggingCallback>,
|
logging_callback: Option<LoggingCallback>,
|
||||||
logging_context: LoggingContext,
|
logging_context: LoggingContext,
|
||||||
) -> i32;
|
) -> i32;
|
||||||
@ -380,7 +395,7 @@ mod ffi {
|
|||||||
exit_settings: *const c_char,
|
exit_settings: *const c_char,
|
||||||
entry_settings: *const c_char,
|
entry_settings: *const c_char,
|
||||||
private_ip: *const c_char,
|
private_ip: *const c_char,
|
||||||
fd: Fd,
|
fd: RawFd,
|
||||||
logging_callback: Option<LoggingCallback>,
|
logging_callback: Option<LoggingCallback>,
|
||||||
logging_context: LoggingContext,
|
logging_context: LoggingContext,
|
||||||
) -> i32;
|
) -> i32;
|
||||||
@ -433,11 +448,11 @@ mod ffi {
|
|||||||
|
|
||||||
/// Get the file descriptor of the tunnel IPv4 socket.
|
/// Get the file descriptor of the tunnel IPv4 socket.
|
||||||
#[cfg(target_os = "android")]
|
#[cfg(target_os = "android")]
|
||||||
pub fn wgGetSocketV4(handle: i32) -> Fd;
|
pub fn wgGetSocketV4(handle: i32) -> RawFd;
|
||||||
|
|
||||||
/// Get the file descriptor of the tunnel IPv6 socket.
|
/// Get the file descriptor of the tunnel IPv6 socket.
|
||||||
#[cfg(target_os = "android")]
|
#[cfg(target_os = "android")]
|
||||||
pub fn wgGetSocketV6(handle: i32) -> Fd;
|
pub fn wgGetSocketV6(handle: i32) -> RawFd;
|
||||||
|
|
||||||
/// Rebind endpoint sockets
|
/// Rebind endpoint sockets
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user