Fix tun file descriptor ownership

We accidentally borrowed the file descriptor when we should have moved
it. This commit adds more `OwnedFd` and friends to help handle
ownership correctly.

Signed-off-by: Joakim Hulthe <joakim.hulthe@mullvad.net>
This commit is contained in:
Joakim Hulthe 2025-06-12 14:13:58 +02:00
parent f0efcc68cf
commit b39d040d9f
No known key found for this signature in database
GPG Key ID: 1AE9299832EE47EB
6 changed files with 56 additions and 45 deletions

View File

@ -604,7 +604,7 @@ impl SharedTunnelStateValues {
#[cfg(target_os = "android")]
pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) {
if let Err(err) = self.tun_provider.lock().unwrap().bypass(&fd) {
log::error!("Failed to bypass socket {}", err);
}
let _ = tx.send(());

View File

@ -197,7 +197,7 @@ impl AndroidTunProvider {
}
/// Allow a socket to bypass the tunnel.
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
pub fn bypass(&mut self, socket: &impl AsRawFd) -> Result<(), Error> {
let env = JnixEnv::from(
self.jvm
.attach_current_thread_as_daemon()
@ -212,7 +212,7 @@ impl AndroidTunProvider {
self.object.as_obj(),
create_tun_method,
JavaType::Primitive(Primitive::Boolean),
&[JValue::Int(socket)],
&[JValue::Int(socket.as_raw_fd())],
)
.map_err(|cause| Error::CallMethod("bypass", cause))?;
@ -404,7 +404,7 @@ impl VpnServiceTun {
}
/// Allow a socket to bypass the tunnel.
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
pub fn bypass(&mut self, socket: &impl AsFd) -> Result<(), Error> {
let env = JnixEnv::from(
self.jvm
.attach_current_thread_as_daemon()
@ -419,7 +419,7 @@ impl VpnServiceTun {
self.object.as_obj(),
create_tun_method,
JavaType::Primitive(Primitive::Boolean),
&[JValue::Int(socket)],
&[JValue::Int(socket.as_fd().as_raw_fd())],
)
.map_err(|cause| Error::CallMethod("bypass", cause))?;

View File

@ -90,9 +90,7 @@ pub async fn open_boringtun_tunnel(
let mut config = tun07::Configuration::default();
config.raw_fd(fd);
boringtun_config.on_bind = Some(Box::new(move |socket| {
tun.bypass(socket.as_raw_fd()).unwrap()
}));
boringtun_config.on_bind = Some(Box::new(move |socket| tun.bypass(socket).unwrap()));
let device = tun07::Device::new(&config).unwrap();
tun07::AsyncDevice::new(device).unwrap()

View File

@ -117,7 +117,7 @@ async fn bypass_vpn(
// Exclude remote obfuscation socket or bridge
log::debug!("Excluding remote socket fd from the tunnel");
let _ = tokio::task::spawn_blocking(move || {
if let Err(error) = tun_provider.lock().unwrap().bypass(remote_socket_fd) {
if let Err(error) = tun_provider.lock().unwrap().bypass(&remote_socket_fd) {
log::error!("Failed to exclude remote socket fd: {error}");
}
})

View File

@ -18,8 +18,6 @@ use std::borrow::Cow;
#[cfg(daita)]
use std::ffi::CString;
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(unix)]
use std::sync::{Arc, Mutex};
use std::{
future::Future,
@ -300,10 +298,10 @@ impl WgGoTunnelState {
let socket_v6 = self.tunnel_handle.get_socket_v6();
let mut provider = tun_provider.lock().unwrap();
provider
.bypass(socket_v4)
.bypass(&socket_v4)
.map_err(super::TunnelError::BypassError)?;
provider
.bypass(socket_v6)
.bypass(&socket_v6)
.map_err(super::TunnelError::BypassError)?;
}
@ -334,7 +332,7 @@ impl WgGoTunnel {
let handle = wireguard_go_rs::Tunnel::turn_on(
mtu,
&wg_config_str,
tunnel_fd.as_raw_fd(),
tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.ordinal,
)
@ -529,7 +527,7 @@ impl WgGoTunnel {
let handle = wireguard_go_rs::Tunnel::turn_on(
&wg_config_str,
tunnel_fd.as_raw_fd(),
tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.ordinal,
)
@ -611,7 +609,7 @@ impl WgGoTunnel {
&exit_config_str,
&entry_config_str,
&private_ip,
tunnel_fd.as_raw_fd(),
tunnel_fd,
Some(logging::wg_go_logging_callback),
logging_context.ordinal,
)
@ -658,8 +656,8 @@ impl WgGoTunnel {
let socket_v4 = handle.get_socket_v4();
let socket_v6 = handle.get_socket_v6();
tunnel_device.bypass(socket_v4)?;
tunnel_device.bypass(socket_v6)?;
tunnel_device.bypass(&socket_v4)?;
tunnel_device.bypass(&socket_v6)?;
Ok(())
}

View File

@ -8,18 +8,22 @@
use core::ffi::{c_char, CStr};
use core::mem::ManuallyDrop;
#[cfg(target_os = "windows")]
use core::mem::MaybeUninit;
use core::slice;
#[cfg(target_os = "windows")]
use std::ffi::CString;
use talpid_types::drop_guard::on_drop;
#[cfg(target_os = "windows")]
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
use zeroize::Zeroize;
#[cfg(unix)]
pub type Fd = std::os::unix::io::RawFd;
#[cfg(target_os = "android")]
use std::os::fd::BorrowedFd;
#[cfg(not(target_os = "windows"))]
use std::os::fd::{IntoRawFd, OwnedFd};
#[cfg(target_os = "windows")]
use core::mem::MaybeUninit;
#[cfg(target_os = "windows")]
use std::ffi::CString;
#[cfg(target_os = "windows")]
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
pub type WgLogLevel = u32;
@ -84,17 +88,19 @@ impl Tunnel {
pub fn turn_on(
#[cfg(not(target_os = "android"))] mtu: isize,
settings: &CStr,
device: Fd,
device: OwnedFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> Result<Self, Error> {
// SAFETY: pointer is valid for the lifetime of this function
// SAFETY:
// - pointer is valid for the lifetime of `wgTurnOn`.
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
let code = unsafe {
ffi::wgTurnOn(
#[cfg(not(target_os = "android"))]
mtu,
settings.as_ptr(),
device,
device.into_raw_fd(), // Transfer ownership of the fd to Go
logging_callback,
logging_context,
)
@ -181,17 +187,19 @@ impl Tunnel {
exit_settings: &CStr,
entry_settings: &CStr,
private_ip: &CStr,
device: Fd,
device: OwnedFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> Result<Self, Error> {
// SAFETY: pointer is valid for the lifetime of this function
// SAFETY:
// - pointers are valid for the lifetime of `wgTurnOnMultihop`.
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
let code = unsafe {
ffi::wgTurnOnMultihop(
exit_settings.as_ptr(),
entry_settings.as_ptr(),
private_ip.as_ptr(),
device,
device.into_raw_fd(), // Transfer ownership of the fd to Go
logging_callback,
logging_context,
)
@ -279,16 +287,22 @@ impl Tunnel {
/// Get the file descriptor of the tunnel IPv4 socket.
#[cfg(target_os = "android")]
pub fn get_socket_v4(&self) -> Fd {
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
unsafe { ffi::wgGetSocketV4(self.handle) }
pub fn get_socket_v4(&self) -> BorrowedFd {
// SAFETY:
// - self.handle is a valid pointer to an active wireguard-go tunnel.
// - file descriptor won't be closed until wgTurnOff is called,
// which can't happen while `self` is borrowed.
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV4(self.handle)) }
}
/// Get the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")]
pub fn get_socket_v6(&self) -> Fd {
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
unsafe { ffi::wgGetSocketV6(self.handle) }
pub fn get_socket_v6(&self) -> BorrowedFd {
// SAFETY:
// - self.handle is a valid pointer to an active wireguard-go tunnel.
// - file descriptor won't be closed until wgTurnOff is called,
// which can't happen while `self` is borrowed.
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV6(self.handle)) }
}
}
@ -329,11 +343,12 @@ impl Error {
}
mod ffi {
#[cfg(not(target_os = "windows"))]
use super::Fd;
use super::{LoggingCallback, LoggingContext};
use core::ffi::{c_char, c_void};
#[cfg(not(target_os = "windows"))]
use std::os::fd::RawFd;
unsafe extern "C" {
/// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
/// for the tunnel device and logging. For targets other than android, this also takes an
@ -345,7 +360,7 @@ mod ffi {
pub fn wgTurnOn(
mtu: isize,
settings: *const c_char,
fd: Fd,
fd: RawFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> i32;
@ -353,7 +368,7 @@ mod ffi {
#[cfg(target_os = "android")]
pub fn wgTurnOn(
settings: *const c_char,
fd: Fd,
fd: RawFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> i32;
@ -380,7 +395,7 @@ mod ffi {
exit_settings: *const c_char,
entry_settings: *const c_char,
private_ip: *const c_char,
fd: Fd,
fd: RawFd,
logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext,
) -> i32;
@ -433,11 +448,11 @@ mod ffi {
/// Get the file descriptor of the tunnel IPv4 socket.
#[cfg(target_os = "android")]
pub fn wgGetSocketV4(handle: i32) -> Fd;
pub fn wgGetSocketV4(handle: i32) -> RawFd;
/// Get the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")]
pub fn wgGetSocketV6(handle: i32) -> Fd;
pub fn wgGetSocketV6(handle: i32) -> RawFd;
/// Rebind endpoint sockets
#[cfg(target_os = "windows")]