diff --git a/rust/aa-ffi-python/Cargo.toml b/rust/aa-ffi-python/Cargo.toml index b716d9f..1eb1bf3 100644 --- a/rust/aa-ffi-python/Cargo.toml +++ b/rust/aa-ffi-python/Cargo.toml @@ -14,8 +14,7 @@ aa-core = { git = "https://github.com/AI-agent-assembly/agent-assembly.git", rev aa-proto = { git = "https://github.com/AI-agent-assembly/agent-assembly.git", rev = "ed4aa11a8c1d1ce1e6f96b08cf2179fd772099b2", package = "aa-proto" } once_cell = "1.20" prost = "0.14" -pyo3 = { version = "0.20", features = ["extension-module"] } -pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] } +pyo3 = { version = "0.28", features = ["extension-module"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.41", features = ["io-util", "net", "rt-multi-thread", "sync", "time"] } diff --git a/rust/aa-ffi-python/src/lib.rs b/rust/aa-ffi-python/src/lib.rs index 4cbf47f..698b832 100644 --- a/rust/aa-ffi-python/src/lib.rs +++ b/rust/aa-ffi-python/src/lib.rs @@ -13,7 +13,7 @@ use prost::Message; use pyo3::exceptions::PyValueError; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::types::{PyAny, PyBytes, PyDict, PyList, PyModule}; use std::sync::Arc; use std::sync::Mutex; use std::sync::atomic::{AtomicBool, Ordering}; @@ -170,7 +170,7 @@ impl RuntimeClient { Ok(()) } - fn query_policy(&self, py: Python<'_>, action: &PyAny) -> PyResult { + fn query_policy(&self, py: Python<'_>, action: &Bound<'_, PyAny>) -> PyResult { ensure_client_open(self.closed.as_ref(), self.last_error.as_ref())?; let action_json = serialize_action_to_json(py, action)?; let timeout_ms = extract_timeout_ms(action); @@ -188,7 +188,7 @@ impl RuntimeClient { }) .map_err(|_| PyRuntimeError::new_err("failed to enqueue policy query"))?; - let worker_result = py.allow_threads(|| wait_for_worker_response(timeout_ms + 100, response_rx)); + let worker_result = py.detach(|| wait_for_worker_response(timeout_ms + 100, response_rx)); let worker_result = worker_result.map_err(|error| match error { WorkerWaitError::Timeout => PolicyTimeoutError::new_err("policy query timed out"), WorkerWaitError::Disconnected => PyRuntimeError::new_err("policy worker disconnected"), @@ -376,16 +376,16 @@ fn ensure_client_open(closed: &AtomicBool, last_error: &Mutex>) - Err(PyRuntimeError::new_err("runtime client is closed")) } -fn extract_timeout_ms(action: &PyAny) -> u64 { +fn extract_timeout_ms(action: &Bound<'_, PyAny>) -> u64 { action - .downcast::() + .cast::() .ok() .and_then(|dict| dict.get_item("timeout_ms").ok().flatten()) .and_then(|value| value.extract::().ok()) .unwrap_or(50) } -fn serialize_action_to_json(py: Python<'_>, action: &PyAny) -> PyResult { +fn serialize_action_to_json(py: Python<'_>, action: &Bound<'_, PyAny>) -> PyResult { let json_module = PyModule::import(py, "json")?; let dumped = json_module.call_method1("dumps", (action,))?; dumped.extract::() @@ -440,7 +440,7 @@ fn decision_from_str(value: &str) -> i32 { } } -fn audit_event_from_py(event: &PyAny) -> PyResult { +fn audit_event_from_py(event: &Bound<'_, PyAny>) -> PyResult { let event_id = event.getattr("event_id")?.extract::()?; let agent_id_str = event.getattr("agent_id")?.extract::()?; let action_type_str = event.getattr("action_type")?.extract::()?; @@ -453,8 +453,8 @@ fn audit_event_from_py(event: &PyAny) -> PyResult { .extract::>()?; let call_stack_py = event.getattr("call_stack")?; let mut call_stack = Vec::new(); - for node in call_stack_py.iter()? { - call_stack.push(call_stack_node_from_py(node?)?); + for node in call_stack_py.try_iter()? { + call_stack.push(call_stack_node_from_py(&node?)?); } Ok(AuditEvent { event_id, @@ -474,7 +474,7 @@ fn audit_event_from_py(event: &PyAny) -> PyResult { }) } -fn audit_event_to_py(py: Python<'_>, event: &AuditEvent) -> PyResult { +fn audit_event_to_py(py: Python<'_>, event: &AuditEvent) -> PyResult> { let types_module = PyModule::import(py, "agent_assembly.types")?; let cls = types_module.getattr("AuditEvent")?; let kwargs = PyDict::new(py); @@ -495,15 +495,15 @@ fn audit_event_to_py(py: Python<'_>, event: &AuditEvent) -> PyResult { labels.set_item(k, v)?; } kwargs.set_item("labels", labels)?; - let call_stack = pyo3::types::PyList::empty(py); + let call_stack = PyList::empty(py); for node in &event.call_stack { call_stack.append(call_stack_node_to_py(py, node)?)?; } kwargs.set_item("call_stack", call_stack)?; - Ok(cls.call((), Some(kwargs))?.into()) + Ok(cls.call((), Some(&kwargs))?.unbind()) } -fn call_stack_node_from_py(node: &PyAny) -> PyResult { +fn call_stack_node_from_py(node: &Bound<'_, PyAny>) -> PyResult { let id = node.getattr("id")?.extract::()?; let kind = node.getattr("kind")?.extract::()?; let label = node.getattr("label")?.extract::()?; @@ -513,8 +513,8 @@ fn call_stack_node_from_py(node: &PyAny) -> PyResult { .unwrap_or(0); let children_py = node.getattr("children")?; let mut children = Vec::new(); - for child in children_py.iter()? { - children.push(call_stack_node_from_py(child?)?); + for child in children_py.try_iter()? { + children.push(call_stack_node_from_py(&child?)?); } Ok(ProtoCallStackNode { id, @@ -525,7 +525,7 @@ fn call_stack_node_from_py(node: &PyAny) -> PyResult { }) } -fn call_stack_node_to_py(py: Python<'_>, node: &ProtoCallStackNode) -> PyResult { +fn call_stack_node_to_py(py: Python<'_>, node: &ProtoCallStackNode) -> PyResult> { let types_module = PyModule::import(py, "agent_assembly.types")?; let cls = types_module.getattr("CallStackNode")?; let kwargs = PyDict::new(py); @@ -538,12 +538,12 @@ fn call_stack_node_to_py(py: Python<'_>, node: &ProtoCallStackNode) -> PyResult< Some(node.latency_ms) }; kwargs.set_item("latency_ms", latency)?; - let children = pyo3::types::PyList::empty(py); + let children = PyList::empty(py); for child in &node.children { children.append(call_stack_node_to_py(py, child)?)?; } kwargs.set_item("children", children)?; - Ok(cls.call((), Some(kwargs))?.into()) + Ok(cls.call((), Some(&kwargs))?.unbind()) } fn decision_to_str(value: i32) -> &'static str { @@ -694,21 +694,21 @@ fn wait_for_worker_response( } #[pyfunction] -fn audit_event_to_wire_bytes(py: Python<'_>, event: &PyAny) -> PyResult { +fn audit_event_to_wire_bytes(py: Python<'_>, event: &Bound<'_, PyAny>) -> PyResult> { let proto = audit_event_from_py(event)?; let encoded = proto.encode_to_vec(); - Ok(pyo3::types::PyBytes::new(py, &encoded).into()) + Ok(PyBytes::new(py, &encoded).into_any().unbind()) } #[pyfunction] -fn audit_event_from_wire_bytes(py: Python<'_>, data: &pyo3::types::PyBytes) -> PyResult { +fn audit_event_from_wire_bytes(py: Python<'_>, data: &Bound<'_, PyBytes>) -> PyResult> { let proto = AuditEvent::decode(data.as_bytes()) .map_err(|error| PyValueError::new_err(format!("failed to decode AuditEvent wire bytes: {error}")))?; audit_event_to_py(py, &proto) } #[pymodule] -fn _core(py: Python<'_>, module: &PyModule) -> PyResult<()> { +fn _core(py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { module.add("PolicyTimeoutError", py.get_type::())?; module.add_class::()?; module.add_class::()?;