#![allow(missing_docs)]
use std::cell::Cell;
use std::marker::PhantomData;
use crate::impl_::pyclass::{PyClassBaseType, PyClassImpl};
use crate::PyCell;
use super::{PyBorrowError, PyBorrowMutError};
/// Type-level description of how a pyclass's mutability drives borrow checking,
/// both for the class itself and for classes that subclass it.
pub trait PyClassMutability {
    /// Borrow-checker value kept for this class's own instance data.
    type Storage: PyClassBorrowChecker;
    /// Borrow checker actually consulted when borrowing this class; for
    /// `ExtendsMutableAncestor` this differs from `Storage` and is found on an ancestor.
    type Checker: PyClassBorrowChecker;

    /// Mutability marker a frozen subclass of this class ends up with.
    type ImmutableChild: PyClassMutability;
    /// Mutability marker a mutable (non-frozen) subclass of this class ends up with.
    type MutableChild: PyClassMutability;
}
/// Marker for a frozen class with no mutable ancestor: no runtime borrow tracking at all.
pub struct ImmutableClass(());
/// Marker for a mutable class with no mutable ancestor: it owns the runtime borrow flag.
pub struct MutableClass(());
/// Marker for a class that has a mutable ancestor; `M` records whether the class
/// itself is mutable (`MutableClass`) or frozen (`ImmutableClass`).
pub struct ExtendsMutableAncestor<M: PyClassMutability>(PhantomData<M>);

impl PyClassMutability for MutableClass {
    // The mutable root stores the flag and checks against it.
    type Storage = BorrowChecker;
    type Checker = BorrowChecker;
    // Every descendant, frozen or not, shares the root's flag from here on.
    type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
    type MutableChild = ExtendsMutableAncestor<MutableClass>;
}

impl PyClassMutability for ImmutableClass {
    // Nothing can mutate this class, so there is nothing to track.
    type Storage = EmptySlot;
    type Checker = EmptySlot;
    // A mutable subclass becomes the first mutable class in the chain.
    type ImmutableChild = ImmutableClass;
    type MutableChild = MutableClass;
}

impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
    // No flag of its own; borrows are checked against the ancestor's `BorrowChecker`.
    type Storage = EmptySlot;
    type Checker = BorrowChecker;
    type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
    type MutableChild = ExtendsMutableAncestor<MutableClass>;
}
/// Borrow state of a mutable pyclass instance.
///
/// `UNUSED` (0) means no borrow is active, `HAS_MUTABLE_BORROW` (`usize::MAX`)
/// means one exclusive borrow is active, and any other value is the number of
/// active shared borrows.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct BorrowFlag(usize);

impl BorrowFlag {
    /// No borrow of any kind is active.
    pub(crate) const UNUSED: BorrowFlag = BorrowFlag(0);
    /// Sentinel marking an active exclusive borrow.
    // `usize::MAX` is the modern spelling of the legacy `usize::max_value()`; same value.
    const HAS_MUTABLE_BORROW: BorrowFlag = BorrowFlag(usize::MAX);

    /// Returns this flag with one more shared borrow recorded.
    ///
    /// Must not be called on `HAS_MUTABLE_BORROW`: `usize::MAX + 1` overflows.
    const fn increment(self) -> Self {
        Self(self.0 + 1)
    }

    /// Returns this flag with one fewer shared borrow recorded.
    ///
    /// Must not be called on `UNUSED`: `0 - 1` underflows.
    const fn decrement(self) -> Self {
        Self(self.0 - 1)
    }
}
/// Borrow-checker placeholder with no state: shared borrows always succeed and the
/// mutable-borrow methods of its `PyClassBorrowChecker` impl are `unreachable!()`.
pub struct EmptySlot(());
/// Runtime borrow checker: a `BorrowFlag` in a `Cell` (so this type is not `Sync`).
pub struct BorrowChecker(Cell<BorrowFlag>);
/// Runtime borrow-tracking operations for a pyclass instance.
///
/// The `try_*` methods report conflicts as errors. Each successful `try_borrow` /
/// `try_borrow_mut` is meant to be paired with the matching `release_*` call.
pub trait PyClassBorrowChecker {
    /// Creates a checker with no borrows recorded.
    fn new() -> Self;

    /// Records a shared borrow, failing if an exclusive borrow is active.
    fn try_borrow(&self) -> Result<(), PyBorrowError>;
    /// Checks that a shared borrow would succeed right now, without recording one.
    fn try_borrow_unguarded(&self) -> Result<(), PyBorrowError>;
    /// Undoes one successful `try_borrow`.
    fn release_borrow(&self);

    /// Records an exclusive borrow; in `BorrowChecker`, fails if any borrow is active.
    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError>;
    /// Undoes one successful `try_borrow_mut`.
    fn release_borrow_mut(&self);
}
impl PyClassBorrowChecker for EmptySlot {
    #[inline]
    fn new() -> Self {
        Self(())
    }

    // With no state to consult, every shared borrow is allowed and nothing is recorded.
    #[inline]
    fn try_borrow(&self) -> Result<(), PyBorrowError> {
        Ok(())
    }

    #[inline]
    fn try_borrow_unguarded(&self) -> Result<(), PyBorrowError> {
        self.try_borrow()
    }

    #[inline]
    fn release_borrow(&self) {}

    // Callers are expected never to request an exclusive borrow through an
    // `EmptySlot`; reaching either of these is a bug.
    #[inline]
    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
        unreachable!()
    }

    #[inline]
    fn release_borrow_mut(&self) {
        unreachable!()
    }
}
impl PyClassBorrowChecker for BorrowChecker {
#[inline]
fn new() -> Self {
Self(Cell::new(BorrowFlag::UNUSED))
}
fn try_borrow(&self) -> Result<(), PyBorrowError> {
let flag = self.0.get();
if flag != BorrowFlag::HAS_MUTABLE_BORROW {
self.0.set(flag.increment());
Ok(())
} else {
Err(PyBorrowError { _private: () })
}
}
fn try_borrow_unguarded(&self) -> Result<(), PyBorrowError> {
let flag = self.0.get();
if flag != BorrowFlag::HAS_MUTABLE_BORROW {
Ok(())
} else {
Err(PyBorrowError { _private: () })
}
}
fn release_borrow(&self) {
let flag = self.0.get();
self.0.set(flag.decrement())
}
fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
let flag = self.0.get();
if flag == BorrowFlag::UNUSED {
self.0.set(BorrowFlag::HAS_MUTABLE_BORROW);
Ok(())
} else {
Err(PyBorrowMutError { _private: () })
}
}
fn release_borrow_mut(&self) {
self.0.set(BorrowFlag::UNUSED)
}
}
/// Locates the borrow checker that governs a `PyCell<T>`.
///
/// For `ImmutableClass` and `MutableClass` it is stored in the cell's own contents;
/// for `ExtendsMutableAncestor` it is looked up recursively through the base-class cell.
pub trait GetBorrowChecker<T: PyClassImpl> {
    fn borrow_checker(cell: &PyCell<T>) -> &<T::PyClassMutability as PyClassMutability>::Checker;
}

impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
    fn borrow_checker(cell: &PyCell<T>) -> &EmptySlot {
        &cell.contents.borrow_checker
    }
}

impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
    fn borrow_checker(cell: &PyCell<T>) -> &BorrowChecker {
        &cell.contents.borrow_checker
    }
}

impl<T, M> GetBorrowChecker<T> for ExtendsMutableAncestor<M>
where
    T: PyClassImpl<PyClassMutability = Self>,
    M: PyClassMutability,
    T::BaseType: PyClassImpl<Layout = PyCell<T::BaseType>>
        + PyClassBaseType<LayoutAsBase = PyCell<T::BaseType>>,
    <T::BaseType as PyClassImpl>::PyClassMutability: PyClassMutability<Checker = BorrowChecker>,
{
    fn borrow_checker(cell: &PyCell<T>) -> &BorrowChecker {
        // This class keeps no flag of its own; defer to whatever governs the base class.
        let base_cell: &PyCell<T::BaseType> = &cell.ob_base;
        <<T::BaseType as PyClassImpl>::PyClassMutability as GetBorrowChecker<T::BaseType>>::borrow_checker(base_cell)
    }
}
#[cfg(test)]
#[cfg(feature = "macros")]
mod tests {
    // Exercises the type-level mutability inheritance rules and the runtime
    // borrow checker shared along a chain of mutable pyclasses.
    use super::*;
    use crate::impl_::pyclass::{PyClassBaseType, PyClassImpl};
    use crate::prelude::*;
    use crate::pyclass::boolean_struct::{False, True};
    use crate::PyClass;

    #[pyclass(crate = "crate", subclass)]
    struct MutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, subclass)]
    struct MutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, frozen, subclass)]
    struct ImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase)]
    struct MutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase)]
    struct MutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", frozen, subclass)]
    struct ImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, subclass)]
    struct MutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, frozen, subclass)]
    struct ImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase)]
    struct MutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase)]
    struct MutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfImmutableBase;

    // Compile-time assertions: each only type-checks if `T` has the stated mutability.
    fn assert_mutable<T: PyClass<Frozen = False, PyClassMutability = MutableClass>>() {}

    fn assert_immutable<T: PyClass<Frozen = True, PyClassMutability = ImmutableClass>>() {}

    fn assert_mutable_with_mutable_ancestor<
        T: PyClass<Frozen = False, PyClassMutability = ExtendsMutableAncestor<MutableClass>>,
    >()
    where
        <T as PyClassImpl>::BaseType: PyClassImpl<Layout = PyCell<T::BaseType>>,
        <<T as PyClassImpl>::BaseType as PyClassImpl>::PyClassMutability:
            PyClassMutability<Checker = BorrowChecker>,
        <T as PyClassImpl>::BaseType: PyClassBaseType<LayoutAsBase = PyCell<T::BaseType>>,
    {
    }

    fn assert_immutable_with_mutable_ancestor<
        T: PyClass<Frozen = True, PyClassMutability = ExtendsMutableAncestor<ImmutableClass>>,
    >()
    where
        <T as PyClassImpl>::BaseType: PyClassImpl<Layout = PyCell<T::BaseType>>,
        <<T as PyClassImpl>::BaseType as PyClassImpl>::PyClassMutability:
            PyClassMutability<Checker = BorrowChecker>,
        <T as PyClassImpl>::BaseType: PyClassBaseType<LayoutAsBase = PyCell<T::BaseType>>,
    {
    }

    #[test]
    fn test_inherited_mutability() {
        assert_mutable::<MutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfImmutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfImmutableChildOfMutableBase>();

        assert_immutable::<ImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableChildOfImmutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfImmutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfImmutableBase>();
    }

    #[test]
    fn test_mutable_borrow_prevents_further_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_cell: &PyCell<MutableChildOfMutableChildOfMutableBase> = mmm.as_ref(py);

            let mmm_refmut = mmm_cell.borrow_mut();

            // While the exclusive borrow is held, no borrow at any level may succeed.
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_cell.extract::<PyRef<'_, MutableBase>>().is_err());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_cell.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_refmut);

            // Once released, every kind of borrow works again.
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    #[test]
    fn test_immutable_borrows_prevent_mutable_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_cell: &PyCell<MutableChildOfMutableChildOfMutableBase> = mmm.as_ref(py);

            // A shared borrow (previously misnamed `mmm_refmut`).
            let mmm_ref = mmm_cell.borrow();

            // Further shared borrows are fine; exclusive borrows at any level are not.
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_cell.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_ref);

            // Once released, both shared and exclusive borrows work again.
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_cell.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }
}