use std::{
cell::UnsafeCell,
sync::atomic::{self, AtomicBool},
};
use crate::{exceptions::PyImportError, ffi, types::PyModule, Py, PyResult, Python};
/// Everything PyO3 needs in order to build one extension module.
///
/// The raw `ffi::PyModuleDef` sits in an `UnsafeCell` because CPython's `PyModule_Create`
/// is handed a `*mut` pointer to it, while the definition itself lives behind a shared
/// (`&'static`) reference.
pub struct ModuleDef {
    /// Raw definition passed to `PyModule_Create`.
    ffi_def: UnsafeCell<ffi::PyModuleDef>,
    /// Rust callback that fills in the freshly created module.
    initializer: ModuleInitializer,
    /// Guards against producing the module more than once per process.
    initialized: AtomicBool,
}

// SAFETY: the only non-`Sync` part is the `UnsafeCell<ffi::PyModuleDef>`. Its contents are
// written once in `ModuleDef::new` and afterwards only handed to CPython from
// `ModuleDef::make_module`, which requires a `Python<'_>` token (i.e. the GIL is held), so
// accesses from different threads are serialized by the interpreter. `ModuleInitializer`
// is a plain function pointer and `AtomicBool` is already `Sync`.
unsafe impl Sync for ModuleDef {}

/// Function invoked with the new module so it can add classes, functions and constants.
pub struct ModuleInitializer(pub for<'py> fn(Python<'py>, &PyModule) -> PyResult<()>);
impl ModuleDef {
    /// Builds a module definition from its name, docstring and initializer.
    ///
    /// This is a `const fn` so the result can be stored in a `static`, which is what
    /// [`ModuleDef::make_module`] (taking `&'static self`) expects.
    ///
    /// # Safety
    ///
    /// `name` and `doc` are handed to CPython as C strings through raw pointers, so each
    /// must end with a nul byte (e.g. `"my_module\0"`). CPython reads them up to the first
    /// nul, so interior nul bytes truncate them. Nothing here checks this.
    pub const unsafe fn new(
        name: &'static str,
        doc: &'static str,
        initializer: ModuleInitializer,
    ) -> Self {
        // Template for every field the caller does not supply.
        const INIT: ffi::PyModuleDef = ffi::PyModuleDef {
            m_base: ffi::PyModuleDef_HEAD_INIT,
            m_name: std::ptr::null(),
            m_doc: std::ptr::null(),
            m_size: 0,
            m_methods: std::ptr::null_mut(),
            m_slots: std::ptr::null_mut(),
            m_traverse: None,
            m_clear: None,
            m_free: None,
        };
        let ffi_def = UnsafeCell::new(ffi::PyModuleDef {
            m_name: name.as_ptr() as *const _,
            m_doc: doc.as_ptr() as *const _,
            ..INIT
        });
        ModuleDef {
            ffi_def,
            initializer,
            initialized: AtomicBool::new(false),
        }
    }

    /// Creates the Python module described by this definition and runs its initializer.
    ///
    /// # Errors
    ///
    /// - `ImportError` if this definition has already produced a module in this process,
    ///   or another call is currently initializing it.
    /// - Any error raised by `PyModule_Create` or by the initializer.
    ///
    /// If creation or initialization fails, the definition is released again. A later
    /// import attempt can then retry instead of getting a spurious "initialized once" error.
    pub fn make_module(&'static self, py: Python<'_>) -> PyResult<Py<PyModule>> {
        #[cfg(all(PyPy, not(Py_3_8)))]
        {
            const PYPY_GOOD_VERSION: [u8; 3] = [7, 3, 8];
            let version = py
                .import("sys")?
                .getattr("implementation")?
                .getattr("version")?;
            if version.lt(crate::types::PyTuple::new(py, &PYPY_GOOD_VERSION))? {
                let warn = py.import("warnings")?.getattr("warn")?;
                warn.call1((
                    "PyPy 3.7 versions older than 7.3.8 are known to have binary \
                     compatibility issues which may cause segfaults. Please upgrade.",
                ))?;
            }
        }

        // Claim the definition before doing any work. Repeated or concurrent imports are
        // then rejected without allocating a throwaway module object.
        if self.initialized.swap(true, atomic::Ordering::SeqCst) {
            return Err(PyImportError::new_err(
                "PyO3 modules may only be initialized once per interpreter process",
            ));
        }

        let result = self.create_and_initialize(py);
        if result.is_err() {
            // No module was handed out, so let a later import attempt try again.
            self.initialized.store(false, atomic::Ordering::SeqCst);
        }
        result
    }

    /// Allocates the module object from `ffi_def` and passes it to the user initializer.
    fn create_and_initialize(&'static self, py: Python<'_>) -> PyResult<Py<PyModule>> {
        // SAFETY: `ffi_def` was built by `new`, whose caller guarantees nul-terminated
        // 'static strings. `self` is 'static, so the definition outlives the module that
        // refers to it. The GIL is held via `py`. `from_owned_ptr_or_err` takes ownership
        // of the new reference, or turns a NULL return into the pending Python error.
        let module = unsafe {
            Py::<PyModule>::from_owned_ptr_or_err(py, ffi::PyModule_Create(self.ffi_def.get()))?
        };
        (self.initializer.0)(py, module.as_ref(py))?;
        Ok(module)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use crate::{types::PyModule, PyResult, Python};

    use super::{ModuleDef, ModuleInitializer};

    /// A module built from a static `ModuleDef` exposes its name and docstring, plus
    /// whatever the initializer added.
    #[test]
    fn module_init() {
        // SAFETY: both strings are 'static and nul-terminated.
        static MODULE_DEF: ModuleDef = unsafe {
            ModuleDef::new(
                "test_module\0",
                "some doc\0",
                ModuleInitializer(|_, m| {
                    m.add("SOME_CONSTANT", 42)?;
                    Ok(())
                }),
            )
        };

        Python::with_gil(|py| {
            let module = MODULE_DEF.make_module(py).unwrap().into_ref(py);

            let name: &str = module.getattr("__name__").unwrap().extract().unwrap();
            assert_eq!(name, "test_module");

            let doc: &str = module.getattr("__doc__").unwrap().extract().unwrap();
            assert_eq!(doc, "some doc");

            let constant: u8 = module.getattr("SOME_CONSTANT").unwrap().extract().unwrap();
            assert_eq!(constant, 42);
        });
    }

    /// `ModuleDef::new` stores the caller's string pointers verbatim and keeps the
    /// initializer callable.
    #[test]
    fn module_def_new() {
        static NAME: &str = "test_module\0";
        static DOC: &str = "some doc\0";
        static INIT_CALLED: AtomicBool = AtomicBool::new(false);

        // The `PyResult` return is required by the `ModuleInitializer` signature.
        #[allow(clippy::unnecessary_wraps)]
        fn init(_: Python<'_>, _: &PyModule) -> PyResult<()> {
            INIT_CALLED.store(true, Ordering::SeqCst);
            Ok(())
        }

        // SAFETY: NAME and DOC are 'static and nul-terminated.
        let module_def: ModuleDef = unsafe { ModuleDef::new(NAME, DOC, ModuleInitializer(init)) };

        // SAFETY: `module_def` is a local that nothing else is accessing; this only reads.
        let (m_name, m_doc) = unsafe {
            let raw = &*module_def.ffi_def.get();
            (raw.m_name, raw.m_doc)
        };
        assert_eq!(m_name, NAME.as_ptr() as _);
        assert_eq!(m_doc, DOC.as_ptr() as _);

        Python::with_gil(|py| {
            let builtins = py.import("builtins").unwrap();
            (module_def.initializer.0)(py, builtins).unwrap();
            assert!(INIT_CALLED.load(Ordering::SeqCst));
        });
    }
}