|
1 | | -use std::{ |
2 | | - collections::{hash_map::Entry, HashMap}, |
3 | | - sync::Arc, |
4 | | -}; |
| 1 | +use std::sync::Arc; |
5 | 2 |
|
6 | 3 | use futures::{stream, FutureExt, StreamExt, TryStreamExt}; |
7 | 4 | use futures_channel::mpsc::UnboundedReceiver; |
8 | 5 | use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; |
9 | 6 | use postgres_openssl::MakeTlsConnector; |
10 | | -use pyo3::{pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python}; |
11 | | -use pyo3_async_runtimes::TaskLocals; |
| 7 | +use pyo3::{pyclass, pymethods, Py, PyAny, PyErr, Python}; |
12 | 8 | use tokio::{sync::RwLock, task::{AbortHandle, JoinHandle}}; |
13 | | -use tokio_postgres::{AsyncMessage, Config, Notification}; |
| 9 | +use tokio_postgres::{AsyncMessage, Config}; |
14 | 10 |
|
15 | 11 | use crate::{ |
16 | | - driver::utils::is_coroutine_function, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, runtime::{rustdriver_future, tokio_runtime} |
17 | | -}; |
18 | | - |
19 | | -use super::{ |
20 | | - common_options::SslMode, connection::{Connection, InnerConnection}, utils::{build_tls, ConfiguredTLS} |
| 12 | + driver::{common_options::SslMode, connection::{Connection, InnerConnection}, utils::{build_tls, is_coroutine_function, ConfiguredTLS}}, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, runtime::{rustdriver_future, tokio_runtime} |
21 | 13 | }; |
22 | 14 |
|
23 | | -struct ChannelCallbacks(HashMap<String, Vec<ListenerCallback>>); |
24 | | - |
25 | | -impl Default for ChannelCallbacks { |
26 | | - fn default() -> Self { |
27 | | - ChannelCallbacks(Default::default()) |
28 | | - } |
29 | | -} |
30 | | - |
31 | | -impl ChannelCallbacks { |
32 | | - fn add_callback(&mut self, channel: String, callback: ListenerCallback) { |
33 | | - match self.0.entry(channel) { |
34 | | - Entry::Vacant(e) => { |
35 | | - e.insert(vec![callback]); |
36 | | - } |
37 | | - Entry::Occupied(mut e) => { |
38 | | - e.get_mut().push(callback); |
39 | | - } |
40 | | - }; |
41 | | - } |
42 | | - |
43 | | - fn retrieve_channel_callbacks(&self, channel: String) -> Option<&Vec<ListenerCallback>> { |
44 | | - self.0.get(&channel) |
45 | | - } |
46 | | - |
47 | | - fn clear_channel_callbacks(&mut self, channel: String) { |
48 | | - self.0.remove(&channel); |
49 | | - } |
50 | | - |
51 | | - fn retrieve_all_channels(&self) -> Vec<&String> { |
52 | | - self.0.keys().collect::<Vec<&String>>() |
53 | | - } |
54 | | -} |
55 | | - |
56 | | - |
57 | | -#[derive(Clone, Debug)] |
58 | | -pub struct ListenerNotification { |
59 | | - pub process_id: i32, |
60 | | - pub channel: String, |
61 | | - pub payload: String, |
62 | | -} |
63 | | - |
64 | | -impl From::<Notification> for ListenerNotification { |
65 | | - fn from(value: Notification) -> Self { |
66 | | - ListenerNotification { |
67 | | - process_id: value.process_id(), |
68 | | - channel: String::from(value.channel()), |
69 | | - payload: String::from(value.payload()), |
70 | | - } |
71 | | - } |
72 | | -} |
73 | | - |
74 | | -#[pyclass] |
75 | | -pub struct ListenerNotificationMsg { |
76 | | - process_id: i32, |
77 | | - channel: String, |
78 | | - payload: String, |
79 | | - connection: Connection, |
80 | | -} |
| 15 | +use super::structs::{ChannelCallbacks, ListenerCallback, ListenerNotification, ListenerNotificationMsg}; |
81 | 16 |
|
82 | | -#[pymethods] |
83 | | -impl ListenerNotificationMsg { |
84 | | - #[getter] |
85 | | - fn process_id(&self) -> i32 { |
86 | | - self.process_id |
87 | | - } |
88 | | - |
89 | | - #[getter] |
90 | | - fn channel(&self) -> String { |
91 | | - self.channel.clone() |
92 | | - } |
93 | | - |
94 | | - #[getter] |
95 | | - fn payload(&self) -> String { |
96 | | - self.payload.clone() |
97 | | - } |
98 | | - |
99 | | - #[getter] |
100 | | - fn connection(&self) -> Connection { |
101 | | - self.connection.clone() |
102 | | - } |
103 | | -} |
104 | | - |
105 | | -impl ListenerNotificationMsg { |
106 | | - fn new(value: ListenerNotification, conn: Connection) -> Self { |
107 | | - ListenerNotificationMsg { |
108 | | - process_id: value.process_id, |
109 | | - channel: String::from(value.channel), |
110 | | - payload: String::from(value.payload), |
111 | | - connection: conn, |
112 | | - } |
113 | | - } |
114 | | -} |
115 | | - |
116 | | -struct ListenerCallback { |
117 | | - task_locals: Option<TaskLocals>, |
118 | | - callback: Py<PyAny>, |
119 | | -} |
120 | | - |
121 | | -impl ListenerCallback { |
122 | | - pub fn new( |
123 | | - task_locals: Option<TaskLocals>, |
124 | | - callback: Py<PyAny>, |
125 | | - ) -> Self { |
126 | | - ListenerCallback { |
127 | | - task_locals, |
128 | | - callback, |
129 | | - } |
130 | | - } |
131 | | - |
132 | | - async fn call( |
133 | | - &self, |
134 | | - lister_notification: ListenerNotification, |
135 | | - connection: Connection, |
136 | | - ) -> RustPSQLDriverPyResult<()> { |
137 | | - let (callback, task_locals) = Python::with_gil(|py| { |
138 | | - if let Some(task_locals) = &self.task_locals { |
139 | | - return (self.callback.clone(), Some(task_locals.clone_ref(py))); |
140 | | - } |
141 | | - (self.callback.clone(), None) |
142 | | - }); |
143 | | - |
144 | | - if let Some(task_locals) = task_locals { |
145 | | - tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { |
146 | | - let future = Python::with_gil(|py| { |
147 | | - let awaitable = callback.call1( |
148 | | - py, |
149 | | - ( |
150 | | - lister_notification.channel, |
151 | | - lister_notification.payload, |
152 | | - lister_notification.process_id, |
153 | | - connection, |
154 | | - ) |
155 | | - ).unwrap(); |
156 | | - pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)).unwrap() |
157 | | - }); |
158 | | - future.await.unwrap(); |
159 | | - })).await?; |
160 | | - }; |
161 | | - |
162 | | - Ok(()) |
163 | | - } |
164 | | -} |
165 | 17 |
|
166 | 18 | #[pyclass] |
167 | 19 | pub struct Listener { |
|
0 commit comments