Fix tun file descriptor ownership
We accidentally borrowed the file descriptor when we should have moved it. This commit adds more `OwnedFd` and friends to help handle ownership correctly. Signed-off-by: Joakim Hulthe <joakim.hulthe@mullvad.net>
This commit is contained in:
parent
f0efcc68cf
commit
b39d040d9f
@ -604,7 +604,7 @@ impl SharedTunnelStateValues {
|
||||
|
||||
#[cfg(target_os = "android")]
|
||||
pub fn bypass_socket(&mut self, fd: RawFd, tx: oneshot::Sender<()>) {
|
||||
if let Err(err) = self.tun_provider.lock().unwrap().bypass(fd) {
|
||||
if let Err(err) = self.tun_provider.lock().unwrap().bypass(&fd) {
|
||||
log::error!("Failed to bypass socket {}", err);
|
||||
}
|
||||
let _ = tx.send(());
|
||||
|
@ -197,7 +197,7 @@ impl AndroidTunProvider {
|
||||
}
|
||||
|
||||
/// Allow a socket to bypass the tunnel.
|
||||
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
|
||||
pub fn bypass(&mut self, socket: &impl AsRawFd) -> Result<(), Error> {
|
||||
let env = JnixEnv::from(
|
||||
self.jvm
|
||||
.attach_current_thread_as_daemon()
|
||||
@ -212,7 +212,7 @@ impl AndroidTunProvider {
|
||||
self.object.as_obj(),
|
||||
create_tun_method,
|
||||
JavaType::Primitive(Primitive::Boolean),
|
||||
&[JValue::Int(socket)],
|
||||
&[JValue::Int(socket.as_raw_fd())],
|
||||
)
|
||||
.map_err(|cause| Error::CallMethod("bypass", cause))?;
|
||||
|
||||
@ -404,7 +404,7 @@ impl VpnServiceTun {
|
||||
}
|
||||
|
||||
/// Allow a socket to bypass the tunnel.
|
||||
pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> {
|
||||
pub fn bypass(&mut self, socket: &impl AsFd) -> Result<(), Error> {
|
||||
let env = JnixEnv::from(
|
||||
self.jvm
|
||||
.attach_current_thread_as_daemon()
|
||||
@ -419,7 +419,7 @@ impl VpnServiceTun {
|
||||
self.object.as_obj(),
|
||||
create_tun_method,
|
||||
JavaType::Primitive(Primitive::Boolean),
|
||||
&[JValue::Int(socket)],
|
||||
&[JValue::Int(socket.as_fd().as_raw_fd())],
|
||||
)
|
||||
.map_err(|cause| Error::CallMethod("bypass", cause))?;
|
||||
|
||||
|
@ -90,9 +90,7 @@ pub async fn open_boringtun_tunnel(
|
||||
let mut config = tun07::Configuration::default();
|
||||
config.raw_fd(fd);
|
||||
|
||||
boringtun_config.on_bind = Some(Box::new(move |socket| {
|
||||
tun.bypass(socket.as_raw_fd()).unwrap()
|
||||
}));
|
||||
boringtun_config.on_bind = Some(Box::new(move |socket| tun.bypass(socket).unwrap()));
|
||||
|
||||
let device = tun07::Device::new(&config).unwrap();
|
||||
tun07::AsyncDevice::new(device).unwrap()
|
||||
|
@ -117,7 +117,7 @@ async fn bypass_vpn(
|
||||
// Exclude remote obfuscation socket or bridge
|
||||
log::debug!("Excluding remote socket fd from the tunnel");
|
||||
let _ = tokio::task::spawn_blocking(move || {
|
||||
if let Err(error) = tun_provider.lock().unwrap().bypass(remote_socket_fd) {
|
||||
if let Err(error) = tun_provider.lock().unwrap().bypass(&remote_socket_fd) {
|
||||
log::error!("Failed to exclude remote socket fd: {error}");
|
||||
}
|
||||
})
|
||||
|
@ -18,8 +18,6 @@ use std::borrow::Cow;
|
||||
#[cfg(daita)]
|
||||
use std::ffi::CString;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::io::AsRawFd;
|
||||
#[cfg(unix)]
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::{
|
||||
future::Future,
|
||||
@ -300,10 +298,10 @@ impl WgGoTunnelState {
|
||||
let socket_v6 = self.tunnel_handle.get_socket_v6();
|
||||
let mut provider = tun_provider.lock().unwrap();
|
||||
provider
|
||||
.bypass(socket_v4)
|
||||
.bypass(&socket_v4)
|
||||
.map_err(super::TunnelError::BypassError)?;
|
||||
provider
|
||||
.bypass(socket_v6)
|
||||
.bypass(&socket_v6)
|
||||
.map_err(super::TunnelError::BypassError)?;
|
||||
}
|
||||
|
||||
@ -334,7 +332,7 @@ impl WgGoTunnel {
|
||||
let handle = wireguard_go_rs::Tunnel::turn_on(
|
||||
mtu,
|
||||
&wg_config_str,
|
||||
tunnel_fd.as_raw_fd(),
|
||||
tunnel_fd,
|
||||
Some(logging::wg_go_logging_callback),
|
||||
logging_context.ordinal,
|
||||
)
|
||||
@ -529,7 +527,7 @@ impl WgGoTunnel {
|
||||
|
||||
let handle = wireguard_go_rs::Tunnel::turn_on(
|
||||
&wg_config_str,
|
||||
tunnel_fd.as_raw_fd(),
|
||||
tunnel_fd,
|
||||
Some(logging::wg_go_logging_callback),
|
||||
logging_context.ordinal,
|
||||
)
|
||||
@ -611,7 +609,7 @@ impl WgGoTunnel {
|
||||
&exit_config_str,
|
||||
&entry_config_str,
|
||||
&private_ip,
|
||||
tunnel_fd.as_raw_fd(),
|
||||
tunnel_fd,
|
||||
Some(logging::wg_go_logging_callback),
|
||||
logging_context.ordinal,
|
||||
)
|
||||
@ -658,8 +656,8 @@ impl WgGoTunnel {
|
||||
let socket_v4 = handle.get_socket_v4();
|
||||
let socket_v6 = handle.get_socket_v6();
|
||||
|
||||
tunnel_device.bypass(socket_v4)?;
|
||||
tunnel_device.bypass(socket_v6)?;
|
||||
tunnel_device.bypass(&socket_v4)?;
|
||||
tunnel_device.bypass(&socket_v6)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -8,18 +8,22 @@
|
||||
|
||||
use core::ffi::{c_char, CStr};
|
||||
use core::mem::ManuallyDrop;
|
||||
#[cfg(target_os = "windows")]
|
||||
use core::mem::MaybeUninit;
|
||||
use core::slice;
|
||||
#[cfg(target_os = "windows")]
|
||||
use std::ffi::CString;
|
||||
use talpid_types::drop_guard::on_drop;
|
||||
#[cfg(target_os = "windows")]
|
||||
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
|
||||
use zeroize::Zeroize;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub type Fd = std::os::unix::io::RawFd;
|
||||
#[cfg(target_os = "android")]
|
||||
use std::os::fd::BorrowedFd;
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
use std::os::fd::{IntoRawFd, OwnedFd};
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
use core::mem::MaybeUninit;
|
||||
#[cfg(target_os = "windows")]
|
||||
use std::ffi::CString;
|
||||
#[cfg(target_os = "windows")]
|
||||
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
|
||||
|
||||
pub type WgLogLevel = u32;
|
||||
|
||||
@ -84,17 +88,19 @@ impl Tunnel {
|
||||
pub fn turn_on(
|
||||
#[cfg(not(target_os = "android"))] mtu: isize,
|
||||
settings: &CStr,
|
||||
device: Fd,
|
||||
device: OwnedFd,
|
||||
logging_callback: Option<LoggingCallback>,
|
||||
logging_context: LoggingContext,
|
||||
) -> Result<Self, Error> {
|
||||
// SAFETY: pointer is valid for the lifetime of this function
|
||||
// SAFETY:
|
||||
// - pointer is valid for the lifetime of `wgTurnOn`.
|
||||
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
|
||||
let code = unsafe {
|
||||
ffi::wgTurnOn(
|
||||
#[cfg(not(target_os = "android"))]
|
||||
mtu,
|
||||
settings.as_ptr(),
|
||||
device,
|
||||
device.into_raw_fd(), // Transfer ownership of the fd to Go
|
||||
logging_callback,
|
||||
logging_context,
|
||||
)
|
||||
@ -181,17 +187,19 @@ impl Tunnel {
|
||||
exit_settings: &CStr,
|
||||
entry_settings: &CStr,
|
||||
private_ip: &CStr,
|
||||
device: Fd,
|
||||
device: OwnedFd,
|
||||
logging_callback: Option<LoggingCallback>,
|
||||
logging_context: LoggingContext,
|
||||
) -> Result<Self, Error> {
|
||||
// SAFETY: pointer is valid for the lifetime of this function
|
||||
// SAFETY:
|
||||
// - pointers are valid for the lifetime of `wgTurnOnMultihop`.
|
||||
// - OwnedFd asserts that fd is open, and into_raw_fd will transfer ownership to Go.
|
||||
let code = unsafe {
|
||||
ffi::wgTurnOnMultihop(
|
||||
exit_settings.as_ptr(),
|
||||
entry_settings.as_ptr(),
|
||||
private_ip.as_ptr(),
|
||||
device,
|
||||
device.into_raw_fd(), // Transfer ownership of the fd to Go
|
||||
logging_callback,
|
||||
logging_context,
|
||||
)
|
||||
@ -279,16 +287,22 @@ impl Tunnel {
|
||||
|
||||
/// Get the file descriptor of the tunnel IPv4 socket.
|
||||
#[cfg(target_os = "android")]
|
||||
pub fn get_socket_v4(&self) -> Fd {
|
||||
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
|
||||
unsafe { ffi::wgGetSocketV4(self.handle) }
|
||||
pub fn get_socket_v4(&self) -> BorrowedFd {
|
||||
// SAFETY:
|
||||
// - self.handle is a valid pointer to an active wireguard-go tunnel.
|
||||
// - file descriptor won't be closed until wgTurnOff is called,
|
||||
// which can't happen while `self` is borrowed.
|
||||
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV4(self.handle)) }
|
||||
}
|
||||
|
||||
/// Get the file descriptor of the tunnel IPv6 socket.
|
||||
#[cfg(target_os = "android")]
|
||||
pub fn get_socket_v6(&self) -> Fd {
|
||||
// SAFETY: self.handle is a valid pointer to an active wireguard-go tunnel.
|
||||
unsafe { ffi::wgGetSocketV6(self.handle) }
|
||||
pub fn get_socket_v6(&self) -> BorrowedFd {
|
||||
// SAFETY:
|
||||
// - self.handle is a valid pointer to an active wireguard-go tunnel.
|
||||
// - file descriptor won't be closed until wgTurnOff is called,
|
||||
// which can't happen while `self` is borrowed.
|
||||
unsafe { BorrowedFd::borrow_raw(ffi::wgGetSocketV6(self.handle)) }
|
||||
}
|
||||
}
|
||||
|
||||
@ -329,11 +343,12 @@ impl Error {
|
||||
}
|
||||
|
||||
mod ffi {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
use super::Fd;
|
||||
use super::{LoggingCallback, LoggingContext};
|
||||
use core::ffi::{c_char, c_void};
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
use std::os::fd::RawFd;
|
||||
|
||||
unsafe extern "C" {
|
||||
/// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors
|
||||
/// for the tunnel device and logging. For targets other than android, this also takes an
|
||||
@ -345,7 +360,7 @@ mod ffi {
|
||||
pub fn wgTurnOn(
|
||||
mtu: isize,
|
||||
settings: *const c_char,
|
||||
fd: Fd,
|
||||
fd: RawFd,
|
||||
logging_callback: Option<LoggingCallback>,
|
||||
logging_context: LoggingContext,
|
||||
) -> i32;
|
||||
@ -353,7 +368,7 @@ mod ffi {
|
||||
#[cfg(target_os = "android")]
|
||||
pub fn wgTurnOn(
|
||||
settings: *const c_char,
|
||||
fd: Fd,
|
||||
fd: RawFd,
|
||||
logging_callback: Option<LoggingCallback>,
|
||||
logging_context: LoggingContext,
|
||||
) -> i32;
|
||||
@ -380,7 +395,7 @@ mod ffi {
|
||||
exit_settings: *const c_char,
|
||||
entry_settings: *const c_char,
|
||||
private_ip: *const c_char,
|
||||
fd: Fd,
|
||||
fd: RawFd,
|
||||
logging_callback: Option<LoggingCallback>,
|
||||
logging_context: LoggingContext,
|
||||
) -> i32;
|
||||
@ -433,11 +448,11 @@ mod ffi {
|
||||
|
||||
/// Get the file descriptor of the tunnel IPv4 socket.
|
||||
#[cfg(target_os = "android")]
|
||||
pub fn wgGetSocketV4(handle: i32) -> Fd;
|
||||
pub fn wgGetSocketV4(handle: i32) -> RawFd;
|
||||
|
||||
/// Get the file descriptor of the tunnel IPv6 socket.
|
||||
#[cfg(target_os = "android")]
|
||||
pub fn wgGetSocketV6(handle: i32) -> Fd;
|
||||
pub fn wgGetSocketV6(handle: i32) -> RawFd;
|
||||
|
||||
/// Rebind endpoint sockets
|
||||
#[cfg(target_os = "windows")]
|
||||
|
Loading…
x
Reference in New Issue
Block a user