[PATCH 3/4] rust: pci: fix unrestricted &mut pci::Device

From: Danilo Krummrich
Date: Wed Mar 12 2025 - 22:18:36 EST


As of now, pci::Device is implemented as:

#[derive(Clone)]
pub struct Device(ARef<device::Device>);

This may be convenient, but it implies that drivers can call device
methods that require a mutable reference concurrently at any point in
time.

Instead define pci::Device as

pub struct Device<Ctx: DeviceContext = Normal>(
Opaque<bindings::pci_dev>,
PhantomData<Ctx>,
);

and manually implement the AlwaysRefCounted trait.

With this, we can implement methods that should only be called from
bus callbacks (such as probe()) for pci::Device<Core>. Consequently,
this type is made accessible in bus callbacks only.

Arbitrary references taken by the driver are still of type
ARef<pci::Device> and hence don't provide access to methods that are
reserved for bus callbacks.

Fixes: 1bd8b6b2c5d3 ("rust: pci: add basic PCI device / driver abstractions")
Signed-off-by: Danilo Krummrich <dakr@xxxxxxxxxx>
---
drivers/gpu/nova-core/driver.rs | 4 +-
rust/kernel/pci.rs | 126 ++++++++++++++++++++------------
samples/rust/rust_driver_pci.rs | 8 +-
3 files changed, 85 insertions(+), 53 deletions(-)

diff --git a/drivers/gpu/nova-core/driver.rs b/drivers/gpu/nova-core/driver.rs
index 63c19f140fbd..a08fb6599267 100644
--- a/drivers/gpu/nova-core/driver.rs
+++ b/drivers/gpu/nova-core/driver.rs
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: GPL-2.0

-use kernel::{bindings, c_str, pci, prelude::*};
+use kernel::{bindings, c_str, device::Core, pci, prelude::*};

use crate::gpu::Gpu;

@@ -27,7 +27,7 @@ impl pci::Driver for NovaCore {
type IdInfo = ();
const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;

- fn probe(pdev: &mut pci::Device, _info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
+ fn probe(pdev: &pci::Device<Core>, _info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
dev_dbg!(pdev.as_ref(), "Probe Nova Core GPU driver.\n");

pdev.enable_device_mem()?;
diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs
index 386484dcf36e..6357b4ff8d65 100644
--- a/rust/kernel/pci.rs
+++ b/rust/kernel/pci.rs
@@ -6,7 +6,7 @@

use crate::{
alloc::flags::*,
- bindings, container_of, device,
+ bindings, device,
device_id::RawDeviceId,
devres::Devres,
driver,
@@ -17,7 +17,11 @@
types::{ARef, ForeignOwnable, Opaque},
ThisModule,
};
-use core::{ops::Deref, ptr::addr_of_mut};
+use core::{
+ marker::PhantomData,
+ ops::Deref,
+ ptr::{addr_of_mut, NonNull},
+};
use kernel::prelude::*;

/// An adapter for the registration of PCI drivers.
@@ -60,17 +64,16 @@ extern "C" fn probe_callback(
) -> kernel::ffi::c_int {
// SAFETY: The PCI bus only ever calls the probe callback with a valid pointer to a
// `struct pci_dev`.
- let dev = unsafe { device::Device::get_device(addr_of_mut!((*pdev).dev)) };
- // SAFETY: `dev` is guaranteed to be embedded in a valid `struct pci_dev` by the call
- // above.
- let mut pdev = unsafe { Device::from_dev(dev) };
+ //
+ // INVARIANT: `pdev` is valid for the duration of `probe_callback()`.
+ let pdev = unsafe { &*pdev.cast::<Device<device::Core>>() };

// SAFETY: `DeviceId` is a `#[repr(transparent)` wrapper of `struct pci_device_id` and
// does not add additional invariants, so it's safe to transmute.
let id = unsafe { &*id.cast::<DeviceId>() };
let info = T::ID_TABLE.info(id.index());

- match T::probe(&mut pdev, info) {
+ match T::probe(pdev, info) {
Ok(data) => {
// Let the `struct pci_dev` own a reference of the driver's private data.
// SAFETY: By the type invariant `pdev.as_raw` returns a valid pointer to a
@@ -192,7 +195,7 @@ macro_rules! pci_device_table {
/// # Example
///
///```
-/// # use kernel::{bindings, pci};
+/// # use kernel::{bindings, device::Core, pci};
///
/// struct MyDriver;
///
@@ -210,7 +213,7 @@ macro_rules! pci_device_table {
/// const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
///
/// fn probe(
-/// _pdev: &mut pci::Device,
+/// _pdev: &pci::Device<Core>,
/// _id_info: &Self::IdInfo,
/// ) -> Result<Pin<KBox<Self>>> {
/// Err(ENODEV)
@@ -234,20 +237,23 @@ pub trait Driver {
///
/// Called when a new platform device is added or discovered.
/// Implementers should attempt to initialize the device here.
- fn probe(dev: &mut Device, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>;
+ fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>;
}

/// The PCI device representation.
///
-/// A PCI device is based on an always reference counted `device:Device` instance. Cloning a PCI
-/// device, hence, also increments the base device' reference count.
+/// This structure represents the Rust abstraction for a C `struct pci_dev`. The implementation
+/// abstracts the usage of an already existing C `struct pci_dev` within Rust code that we get
+/// passed from the C side.
///
/// # Invariants
///
-/// `Device` hold a valid reference of `ARef<device::Device>` whose underlying `struct device` is a
-/// member of a `struct pci_dev`.
-#[derive(Clone)]
-pub struct Device(ARef<device::Device>);
+/// A [`Device`] represents a valid `struct pci_dev` created by the C portion of the kernel.
+#[repr(transparent)]
+pub struct Device<Ctx: device::DeviceContext = device::Normal>(
+ Opaque<bindings::pci_dev>,
+ PhantomData<Ctx>,
+);
/// A PCI BAR to perform I/O-Operations on.
///
@@ -256,13 +262,13 @@ pub trait Driver {
/// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O
/// memory mapped PCI bar and its size.
pub struct Bar<const SIZE: usize = 0> {
- pdev: Device,
+ pdev: ARef<Device>,
io: IoRaw<SIZE>,
num: i32,
}

impl<const SIZE: usize> Bar<SIZE> {
- fn new(pdev: Device, num: u32, name: &CStr) -> Result<Self> {
+ fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> {
let len = pdev.resource_len(num)?;
if len == 0 {
return Err(ENOMEM);
@@ -300,12 +306,16 @@ fn new(pdev: Device, num: u32, name: &CStr) -> Result<Self> {
// `pdev` is valid by the invariants of `Device`.
// `ioptr` is guaranteed to be the start of a valid I/O mapped memory region.
// `num` is checked for validity by a previous call to `Device::resource_len`.
- unsafe { Self::do_release(&pdev, ioptr, num) };
+ unsafe { Self::do_release(pdev, ioptr, num) };
return Err(err);
}
};

- Ok(Bar { pdev, io, num })
+ Ok(Bar {
+ pdev: pdev.into(),
+ io,
+ num,
+ })
}

/// # Safety
@@ -351,20 +361,8 @@ fn deref(&self) -> &Self::Target {
}

impl Device {
- /// Create a PCI Device instance from an existing `device::Device`.
- ///
- /// # Safety
- ///
- /// `dev` must be an `ARef<device::Device>` whose underlying `bindings::device` is a member of
- /// a `bindings::pci_dev`.
- pub unsafe fn from_dev(dev: ARef<device::Device>) -> Self {
- Self(dev)
- }
-
fn as_raw(&self) -> *mut bindings::pci_dev {
- // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device`
- // embedded in `struct pci_dev`.
- unsafe { container_of!(self.0.as_raw(), bindings::pci_dev, dev) as _ }
+ self.0.get()
}

/// Returns the PCI vendor ID.
@@ -379,18 +377,6 @@ pub fn device_id(&self) -> u16 {
unsafe { (*self.as_raw()).device }
}

- /// Enable memory resources for this device.
- pub fn enable_device_mem(&self) -> Result {
- // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
- to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) })
- }
-
- /// Enable bus-mastering for this device.
- pub fn set_master(&self) {
- // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
- unsafe { bindings::pci_set_master(self.as_raw()) };
- }
-
/// Returns the size of the given PCI bar resource.
pub fn resource_len(&self, bar: u32) -> Result<bindings::resource_size_t> {
if !Bar::index_is_valid(bar) {
@@ -410,7 +396,7 @@ pub fn iomap_region_sized<const SIZE: usize>(
bar: u32,
name: &CStr,
) -> Result<Devres<Bar<SIZE>>> {
- let bar = Bar::<SIZE>::new(self.clone(), bar, name)?;
+ let bar = Bar::<SIZE>::new(self, bar, name)?;
let devres = Devres::new(self.as_ref(), bar, GFP_KERNEL)?;

Ok(devres)
@@ -422,8 +408,54 @@ pub fn iomap_region(&self, bar: u32, name: &CStr) -> Result<Devres<Bar>> {
}
}

+impl Device<device::Core> {
+ /// Enable memory resources for this device.
+ pub fn enable_device_mem(&self) -> Result {
+ // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
+ to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) })
+ }
+
+ /// Enable bus-mastering for this device.
+ pub fn set_master(&self) {
+ // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
+ unsafe { bindings::pci_set_master(self.as_raw()) };
+ }
+}
+
+impl Deref for Device<device::Core> {
+ type Target = Device;
+
+ fn deref(&self) -> &Self::Target {
+ let ptr: *const Self = self;
+
+ // CAST: `Device<Ctx>` is a transparent wrapper of `Opaque<bindings::pci_dev>`.
+ let ptr = ptr.cast::<Device>();
+
+ // SAFETY: `ptr` was derived from `&self`.
+ unsafe { &*ptr }
+ }
+}
+
+// SAFETY: Instances of `Device` are always reference-counted.
+unsafe impl crate::types::AlwaysRefCounted for Device {
+ fn inc_ref(&self) {
+ // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
+ unsafe { bindings::pci_dev_get(self.as_raw()) };
+ }
+
+ unsafe fn dec_ref(obj: NonNull<Self>) {
+ // SAFETY: The safety requirements guarantee that the refcount is non-zero.
+ unsafe { bindings::pci_dev_put(obj.cast().as_ptr()) }
+ }
+}
+
impl AsRef<device::Device> for Device {
fn as_ref(&self) -> &device::Device {
- &self.0
+ // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
+ // `struct pci_dev`.
+ let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
+
+ // SAFETY: `dev` points to a valid `struct device`.
+ unsafe { device::Device::as_ref(dev) }
}
}
diff --git a/samples/rust/rust_driver_pci.rs b/samples/rust/rust_driver_pci.rs
index 1fb6e44f3395..b90df5f9d1d0 100644
--- a/samples/rust/rust_driver_pci.rs
+++ b/samples/rust/rust_driver_pci.rs
@@ -4,7 +4,7 @@
//!
//! To make this driver probe, QEMU must be run with `-device pci-testdev`.

-use kernel::{bindings, c_str, devres::Devres, pci, prelude::*};
+use kernel::{bindings, c_str, device::Core, devres::Devres, pci, prelude::*, types::ARef};

struct Regs;

@@ -26,7 +26,7 @@ impl TestIndex {
}

struct SampleDriver {
- pdev: pci::Device,
+ pdev: ARef<pci::Device>,
bar: Devres<Bar0>,
}

@@ -62,7 +62,7 @@ impl pci::Driver for SampleDriver {

const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;

- fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
+ fn probe(pdev: &pci::Device<Core>, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
dev_dbg!(
pdev.as_ref(),
"Probe Rust PCI driver sample (PCI ID: 0x{:x}, 0x{:x}).\n",
@@ -77,7 +77,7 @@ fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>

let drvdata = KBox::new(
Self {
- pdev: pdev.clone(),
+ pdev: (&**pdev).into(),
bar,
},
GFP_KERNEL,
--
2.48.1