Fix tun file descriptor ownership

We accidentally borrowed the file descriptor when we should have moved
it. This commit adds more `OwnedFd` and friends to help handle
ownership correctly.

Signed-off-by: Joakim Hulthe <joakim.hulthe@mullvad.net>
This commit is contained in:
Joakim Hulthe 2025-06-12 14:13:58 +02:00
parent f0efcc68cf
commit b39d040d9f
No known key found for this signature in database
GPG Key ID: 1AE9299832EE47EB
6 changed files with 56 additions and 45 deletions

View File

@ -604,7 +604,7 @@ impl SharedTunnelStateValues {
#[cfg(target_os = "android")] #[cfg(target_os = "android")]
pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) { pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) { if let Err(err) = self.tun_provider.lock().unwrap().bypass(&fd) {
log::error!("Failed to bypass socket {}", err); log::error!("Failed to bypass socket {}", err);
} }
let _ = tx.send(()); let _ = tx.send(());

View File

@ -197,7 +197,7 @@ impl AndroidTunProvider {
} }
/// Allow a socket to bypass the tunnel. /// Allow a socket to bypass the tunnel.
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> { pub fn bypass(&mut self, socket: &impl AsRawFd) -> Result<(), Error> {
let env = JnixEnv::from( let env = JnixEnv::from(
self.jvm self.jvm
.attach_current_thread_as_daemon() .attach_current_thread_as_daemon()
@ -212,7 +212,7 @@ impl AndroidTunProvider {
self.object.as_obj(), self.object.as_obj(),
create_tun_method, create_tun_method,
JavaType::Primitive(Primitive::Boolean), JavaType::Primitive(Primitive::Boolean),
&[JValue::Int(socket)], &[JValue::Int(socket.as_raw_fd())],
) )
.map_err(|cause| Error::CallMethod("bypass", cause))?; .map_err(|cause| Error::CallMethod("bypass", cause))?;
@ -404,7 +404,7 @@ impl VpnServiceTun {
} }
/// Allow a socket to bypass the tunnel. /// Allow a socket to bypass the tunnel.
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> { pub fn bypass(&mut self, socket: &impl AsFd) -> Result<(), Error> {
let env = JnixEnv::from( let env = JnixEnv::from(
self.jvm self.jvm
.attach_current_thread_as_daemon() .attach_current_thread_as_daemon()
@ -419,7 +419,7 @@ impl VpnServiceTun {
self.object.as_obj(), self.object.as_obj(),
create_tun_method, create_tun_method,
JavaType::Primitive(Primitive::Boolean), JavaType::Primitive(Primitive::Boolean),
&[JValue::Int(socket)], &[JValue::Int(socket.as_fd().as_raw_fd())],
) )
.map_err(|cause| Error::CallMethod("bypass", cause))?; .map_err(|cause| Error::CallMethod("bypass", cause))?;

View File

@ -90,9 +90,7 @@ pub async fn open_boringtun_tunnel(
let mut config = tun07::Configuration::default(); let mut config = tun07::Configuration::default();
config.raw_fd(fd); config.raw_fd(fd);
boringtun_config.on_bind = Some(Box::new(move |socket| { boringtun_config.on_bind = Some(Box::new(move |socket| tun.bypass(socket).unwrap()));
tun.bypass(socket.as_raw_fd()).unwrap()
}));
let device = tun07::Device::new(&config).unwrap(); let device = tun07::Device::new(&config).unwrap();
tun07::AsyncDevice::new(device).unwrap() tun07::AsyncDevice::new(device).unwrap()

View File

@ -117,7 +117,7 @@ async fn bypass_vpn(
// Exclude remote obfuscation socket or bridge // Exclude remote obfuscation socket or bridge
log::debug!("Excluding remote socket fd from the tunnel"); log::debug!("Excluding remote socket fd from the tunnel");
let _ = tokio::task::spawn_blocking(move || { let _ = tokio::task::spawn_blocking(move || {
if let Err(error) = tun_provider.lock().unwrap().bypass(remote_socket_fd) { if let Err(error) = tun_provider.lock().unwrap().bypass(&remote_socket_fd) {
log::error!("Failed to exclude remote socket fd: {error}"); log::error!("Failed to exclude remote socket fd: {error}");
} }
}) })

View File

@ -18,8 +18,6 @@ use std::borrow::Cow;
#[cfg(daita)] #[cfg(daita)]
use std::ffi::CString; use std::ffi::CString;
#[cfg(unix)] #[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(unix)]
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{ use std::{
future::Future, future::Future,
@ -300,10 +298,10 @@ impl WgGoTunnelState {
let socket_v6 = self.tunnel_handle.get_socket_v6(); let socket_v6 = self.tunnel_handle.get_socket_v6();
let mut provider = tun_provider.lock().unwrap(); let mut provider = tun_provider.lock().unwrap();
provider provider
.bypass(socket_v4) .bypass(&socket_v4)
.map_err(super::TunnelError::BypassError)?; .map_err(super::TunnelError::BypassError)?;
provider provider
.bypass(socket_v6) .bypass(&socket_v6)
.map_err(super::TunnelError::BypassError)?; .map_err(super::TunnelError::BypassError)?;
} }
@ -334,7 +332,7 @@ impl WgGoTunnel {
let handle = wireguard_go_rs::Tunnel::turn_on( let handle = wireguard_go_rs::Tunnel::turn_on(
mtu, mtu,
&wg_config_str, &wg_config_str,
tunnel_fd.as_raw_fd(), tunnel_fd,
Some(logging::wg_go_logging_callback), Some(logging::wg_go_logging_callback),
logging_context.ordinal, logging_context.ordinal,
) )
@ -529,7 +527,7 @@ impl WgGoTunnel {
let handle = wireguard_go_rs::Tunnel::turn_on( let handle = wireguard_go_rs::Tunnel::turn_on(
&wg_config_str, &wg_config_str,
tunnel_fd.as_raw_fd(), tunnel_fd,
Some(logging::wg_go_logging_callback), Some(logging::wg_go_logging_callback),
logging_context.ordinal, logging_context.ordinal,
) )
@ -611,7 +609,7 @@ impl WgGoTunnel {
&exit_config_str, &exit_config_str,
&entry_config_str, &entry_config_str,
&private_ip, &private_ip,
tunnel_fd.as_raw_fd(), tunnel_fd,
Some(logging::wg_go_logging_callback), Some(logging::wg_go_logging_callback),
logging_context.ordinal, logging_context.ordinal,
) )
@ -658,8 +656,8 @@ impl WgGoTunnel {
let socket_v4 = handle.get_socket_v4(); let socket_v4 = handle.get_socket_v4();
let socket_v6 = handle.get_socket_v6(); let socket_v6 = handle.get_socket_v6();
tunnel_device.bypass(socket_v4)?; tunnel_device.bypass(&socket_v4)?;
tunnel_device.bypass(socket_v6)?; tunnel_device.bypass(&socket_v6)?;
Ok(()) Ok(())
} }

View File

@ -8,18 +8,22 @@
use core::ffi::{c_char, CStr}; use core::ffi::{c_char, CStr};
use core::mem::ManuallyDrop; use core::mem::ManuallyDrop;
#[cfg(target_os = "windows")]
use core::mem::MaybeUninit;
use core::slice; use core::slice;
#[cfg(target_os = "windows")]
use std::ffi::CString;
use talpid_types::drop_guard::on_drop; use talpid_types::drop_guard::on_drop;
#[cfg(target_os = "windows")]
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
use zeroize::Zeroize; use zeroize::Zeroize;
#[cfg(unix)] #[cfg(target_os = "android")]
pub type Fd = std::os::unix::io::RawFd; use std::os::fd::BorrowedFd;
#[cfg(not(target_os = "windows"))]
use std::os::fd::{IntoRawFd, OwnedFd};
#[cfg(target_os = "windows")]
use core::mem::MaybeUninit;
#[cfg(target_os = "windows")]
use std::ffi::CString;
#[cfg(target_os = "windows")]
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
pub type WgLogLevel = u32; pub type WgLogLevel = u32;
@ -84,17 +88,19 @@ impl Tunnel {
pub fn turn_on( pub fn turn_on(
#[cfg(not(target_os = "android"))] mtu: isize, #[cfg(not(target_os = "android"))] mtu: isize,
settings: &CStr, settings: &CStr,
device: Fd, device: OwnedFd,
logging_callback: Option<LoggingCallback>, logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext, logging_context: LoggingContext,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
// SAFETY: pointer is valid for the lifetime of this function // SAFETY:
// - pointer is valid for the lifetime of `wgTurnOn`.
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
let code = unsafe { let code = unsafe {
ffi::wgTurnOn( ffi::wgTurnOn(
#[cfg(not(target_os = "android"))] #[cfg(not(target_os = "android"))]
mtu, mtu,
settings.as_ptr(), settings.as_ptr(),
device, device.into_raw_fd(), // Transfer ownership of the fd to Go
logging_callback, logging_callback,
logging_context, logging_context,
) )
@ -181,17 +187,19 @@ impl Tunnel {
exit_settings: &CStr, exit_settings: &CStr,
entry_settings: &CStr, entry_settings: &CStr,
private_ip: &CStr, private_ip: &CStr,
device: Fd, device: OwnedFd,
logging_callback: Option<LoggingCallback>, logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext, logging_context: LoggingContext,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
// SAFETY: pointer is valid for the lifetime of this function // SAFETY:
// - pointers are valid for the lifetime of `wgTurnOnMultihop`.
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
let code = unsafe { let code = unsafe {
ffi::wgTurnOnMultihop( ffi::wgTurnOnMultihop(
exit_settings.as_ptr(), exit_settings.as_ptr(),
entry_settings.as_ptr(), entry_settings.as_ptr(),
private_ip.as_ptr(), private_ip.as_ptr(),
device, device.into_raw_fd(), // Transfer ownership of the fd to Go
logging_callback, logging_callback,
logging_context, logging_context,
) )
@ -279,16 +287,22 @@ impl Tunnel {
/// Get the file descriptor of the tunnel IPv4 socket. /// Get the file descriptor of the tunnel IPv4 socket.
#[cfg(target_os = "android")] #[cfg(target_os = "android")]
pub fn get_socket_v4(&self) -> Fd { pub fn get_socket_v4(&self) -> BorrowedFd {
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel. // SAFETY:
unsafe { ffi::wgGetSocketV4(self.handle) } // - self.handle is a valid pointer to an active wireguard-go tunnel.
// - file descriptor won't be closed until wgTurnOff is called,
// which can't happen while `self` is borrowed.
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV4(self.handle)) }
} }
/// Get the file descriptor of the tunnel IPv6 socket. /// Get the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")] #[cfg(target_os = "android")]
pub fn get_socket_v6(&self) -> Fd { pub fn get_socket_v6(&self) -> BorrowedFd {
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel. // SAFETY:
unsafe { ffi::wgGetSocketV6(self.handle) } // - self.handle is a valid pointer to an active wireguard-go tunnel.
// - file descriptor won't be closed until wgTurnOff is called,
// which can't happen while `self` is borrowed.
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV6(self.handle)) }
} }
} }
@ -329,11 +343,12 @@ impl Error {
} }
mod ffi { mod ffi {
#[cfg(not(target_os = "windows"))]
use super::Fd;
use super::{LoggingCallback, LoggingContext}; use super::{LoggingCallback, LoggingContext};
use core::ffi::{c_char, c_void}; use core::ffi::{c_char, c_void};
#[cfg(not(target_os = "windows"))]
use std::os::fd::RawFd;
unsafe extern "C" { unsafe extern "C" {
/// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
/// for the tunnel device and logging. For targets other than android, this also takes an /// for the tunnel device and logging. For targets other than android, this also takes an
@ -345,7 +360,7 @@ mod ffi {
pub fn wgTurnOn( pub fn wgTurnOn(
mtu: isize, mtu: isize,
settings: *const c_char, settings: *const c_char,
fd: Fd, fd: RawFd,
logging_callback: Option<LoggingCallback>, logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext, logging_context: LoggingContext,
) -> i32; ) -> i32;
@ -353,7 +368,7 @@ mod ffi {
#[cfg(target_os = "android")] #[cfg(target_os = "android")]
pub fn wgTurnOn( pub fn wgTurnOn(
settings: *const c_char, settings: *const c_char,
fd: Fd, fd: RawFd,
logging_callback: Option<LoggingCallback>, logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext, logging_context: LoggingContext,
) -> i32; ) -> i32;
@ -380,7 +395,7 @@ mod ffi {
exit_settings: *const c_char, exit_settings: *const c_char,
entry_settings: *const c_char, entry_settings: *const c_char,
private_ip: *const c_char, private_ip: *const c_char,
fd: Fd, fd: RawFd,
logging_callback: Option<LoggingCallback>, logging_callback: Option<LoggingCallback>,
logging_context: LoggingContext, logging_context: LoggingContext,
) -> i32; ) -> i32;
@ -433,11 +448,11 @@ mod ffi {
/// Get the file descriptor of the tunnel IPv4 socket. /// Get the file descriptor of the tunnel IPv4 socket.
#[cfg(target_os = "android")] #[cfg(target_os = "android")]
pub fn wgGetSocketV4(handle: i32) -> Fd; pub fn wgGetSocketV4(handle: i32) -> RawFd;
/// Get the file descriptor of the tunnel IPv6 socket. /// Get the file descriptor of the tunnel IPv6 socket.
#[cfg(target_os = "android")] #[cfg(target_os = "android")]
pub fn wgGetSocketV6(handle: i32) -> Fd; pub fn wgGetSocketV6(handle: i32) -> RawFd;
/// Rebind endpoint sockets /// Rebind endpoint sockets
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]