1use std::error::Error;
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::sync::{mpsc, Arc, Mutex};
4use std::thread;
5use std::thread::JoinHandle;
6use std::time::{Duration, Instant};
7
8use super::handler::{self, DefaultHandler, Handler, RequestHandler};
9use rmpv::Value;
10
11use super::model;
12
13type Callback = Box<FnMut(Result<Value, Value>) + Send + 'static>;
14type Queue = Arc<Mutex<Vec<(u64, Sender)>>>;
15
16enum Sender {
17 Sync(mpsc::Sender<Result<Value, Value>>),
18 Async(Callback),
19}
20
21impl Sender {
22 fn send(self, res: Result<Value, Value>) {
23 match self {
24 Sender::Sync(sender) => sender.send(res).unwrap(),
25 Sender::Async(mut cb) => cb(res),
26 };
27 }
28}
29
30pub struct Client<R, W>
31where
32 R: Read + Send + 'static,
33 W: Write + Send + 'static,
34{
35 reader: Option<BufReader<R>>,
36 writer: Arc<Mutex<BufWriter<W>>>,
37 dispatch_guard: Option<JoinHandle<()>>,
38 event_loop_started: bool,
39 queue: Queue,
40 msgid_counter: u64,
41}
42
43impl<R, W> Client<R, W>
44where
45 R: Read + Send + 'static,
46 W: Write + Send + 'static,
47{
48 pub fn take_dispatch_guard(&mut self) -> JoinHandle<()> {
49 self.dispatch_guard
50 .take()
51 .expect("Can only take join handle after running event loop")
52 }
53
54 pub fn start_event_loop_channel_handler<H>(
55 &mut self,
56 request_handler: H,
57 ) -> mpsc::Receiver<(String, Vec<Value>)>
58 where
59 H: RequestHandler + Send + 'static,
60 {
61 let (handler, reciever) = handler::channel(request_handler);
62
63 self.dispatch_guard = Some(Self::dispatch_thread(
64 self.queue.clone(),
65 self.reader.take().unwrap(),
66 self.writer.clone(),
67 handler,
68 ));
69 self.event_loop_started = true;
70
71 reciever
72 }
73
74 pub fn start_event_loop_handler<H>(&mut self, handler: H)
75 where
76 H: Handler + Send + 'static,
77 {
78 self.dispatch_guard = Some(Self::dispatch_thread(
79 self.queue.clone(),
80 self.reader.take().unwrap(),
81 self.writer.clone(),
82 handler,
83 ));
84 self.event_loop_started = true;
85 }
86
87 pub fn start_event_loop(&mut self) {
88 self.dispatch_guard = Some(Self::dispatch_thread(
89 self.queue.clone(),
90 self.reader.take().unwrap(),
91 self.writer.clone(),
92 DefaultHandler(),
93 ));
94 self.event_loop_started = true;
95 }
96
97 pub fn new(reader: R, writer: W) -> Self {
98 let queue = Arc::new(Mutex::new(Vec::new()));
99 Client {
100 reader: Some(BufReader::new(reader)),
101 writer: Arc::new(Mutex::new(BufWriter::new(writer))),
102 msgid_counter: 0,
103 queue: queue.clone(),
104 dispatch_guard: None,
105 event_loop_started: false,
106 }
107 }
108
109 pub fn call_async(&mut self, method: String, args: Vec<Value>, cb: Option<Callback>) {
110 if !self.event_loop_started {
111 if let Some(mut cb) = cb {
112 cb(Err(Value::from("Event loop not started")));
113 } else {
114 error!("Event loop not started");
115 }
116 return;
117 }
118
119 self.send_msg_async(method, args, cb);
120 }
121
122 pub fn call_timeout(
123 &mut self,
124 method: &str,
125 args: Vec<Value>,
126 dur: Duration,
127 ) -> Result<Value, Value> {
128 if !self.event_loop_started {
129 return Err(Value::from("Event loop not started"));
130 }
131
132 let instant = Instant::now();
133 let delay = Duration::from_millis(1);
134
135 let receiver = self.send_msg(method, args);
136
137 loop {
138 match receiver.try_recv() {
139 Err(mpsc::TryRecvError::Empty) => {
140 thread::sleep(delay);
141 if instant.elapsed() >= dur {
142 return Err(Value::from(format!("Wait timeout ({})", method)));
143 }
144 }
145 Err(mpsc::TryRecvError::Disconnected) => {
146 return Err(Value::from(format!("Channel disconnected ({})", method)))
147 }
148 Ok(val) => return val,
149 };
150 }
151 }
152
153 fn send_msg_async(&mut self, method: String, params: Vec<Value>, cb: Option<Callback>) {
154 let msgid = self.msgid_counter;
155 self.msgid_counter += 1;
156
157 let req = model::RpcMessage::RpcRequest {
158 msgid,
159 method,
160 params,
161 };
162
163 if let Some(cb) = cb {
164 self.queue.lock().unwrap().push((msgid, Sender::Async(cb)));
165 }
166
167 let writer = &mut *self.writer.lock().unwrap();
168 model::encode(writer, req).expect("Error sending message");
169 }
170
171 fn send_msg(&mut self, method: &str, args: Vec<Value>) -> mpsc::Receiver<Result<Value, Value>> {
172 let msgid = self.msgid_counter;
173 self.msgid_counter += 1;
174
175 let req = model::RpcMessage::RpcRequest {
176 msgid,
177 method: method.to_owned(),
178 params: args,
179 };
180
181 let (sender, receiver) = mpsc::channel();
182 self.queue
183 .lock()
184 .unwrap()
185 .push((msgid, Sender::Sync(sender)));
186
187 let writer = &mut *self.writer.lock().unwrap();
188 model::encode(writer, req).expect("Error sending message");
189
190 receiver
191 }
192
193 pub fn call(
194 &mut self,
195 method: &str,
196 args: Vec<Value>,
197 dur: Option<Duration>,
198 ) -> Result<Value, Value> {
199 match dur {
200 Some(dur) => self.call_timeout(method, args, dur),
201 None => self.call_inf(method, args),
202 }
203 }
204
205 pub fn call_inf(&mut self, method: &str, args: Vec<Value>) -> Result<Value, Value> {
206 if !self.event_loop_started {
207 return Err(Value::from("Event loop not started"));
208 }
209
210 let receiver = self.send_msg(method, args);
211
212 receiver.recv().unwrap()
213 }
214
215 fn send_error_to_callers(queue: &Queue, err: &Box<Error>) {
216 let mut queue = queue.lock().unwrap();
217 queue.drain(0..).for_each(|sender| {
218 sender
219 .1
220 .send(Err(Value::from(format!("Error read response: {}", err))))
221 });
222 }
223
224 fn dispatch_thread<H>(
225 queue: Queue,
226 mut reader: BufReader<R>,
227 writer: Arc<Mutex<BufWriter<W>>>,
228 mut handler: H,
229 ) -> JoinHandle<()>
230 where
231 H: Handler + Send + 'static,
232 {
233 thread::spawn(move || loop {
234 let msg = match model::decode(&mut reader) {
235 Ok(msg) => msg,
236 Err(e) => {
237 error!("Error while reading: {}", e);
238 Self::send_error_to_callers(&queue, &e);
239 return;
240 }
241 };
242 debug!("Get message {:?}", msg);
243 match msg {
244 model::RpcMessage::RpcRequest {
245 msgid,
246 method,
247 params,
248 } => {
249 let response = match handler.handle_request(&method, params) {
250 Ok(result) => model::RpcMessage::RpcResponse {
251 msgid,
252 result,
253 error: Value::Nil,
254 },
255 Err(error) => model::RpcMessage::RpcResponse {
256 msgid,
257 result: Value::Nil,
258 error,
259 },
260 };
261
262 let writer = &mut *writer.lock().unwrap();
263 model::encode(writer, response).expect("Error sending RPC response");
264 }
265 model::RpcMessage::RpcResponse {
266 msgid,
267 result,
268 error,
269 } => {
270 let sender = find_sender(&queue, msgid);
271 if error != Value::Nil {
272 sender.send(Err(error));
273 } else {
274 sender.send(Ok(result));
275 }
276 }
277 model::RpcMessage::RpcNotification { method, params } => {
278 handler.handle_notify(&method, params);
279 }
280 };
281 })
282 }
283}
284
285fn find_sender(queue: &Queue, msgid: u64) -> Sender {
290 let mut queue = queue.lock().unwrap();
291
292 let pos = queue.iter().position(|req| req.0 == msgid).unwrap();
293 queue.remove(pos).1
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Queues a sync sender under `msgid`. Its receiver is dropped right
    /// away.
    fn push_pending(queue: &Queue, msgid: u64) {
        let (tx, _rx) = mpsc::channel();
        queue.lock().unwrap().push((msgid, Sender::Sync(tx)));
    }

    #[test]
    fn test_find_sender() {
        let queue: Queue = Arc::new(Mutex::new(Vec::new()));
        for msgid in 1..4 {
            push_pending(&queue, msgid);
        }

        // Each lookup removes exactly one entry.
        for &(msgid, remaining) in &[(1u64, 2usize), (2, 1), (3, 0)] {
            find_sender(&queue, msgid);
            assert_eq!(remaining, queue.lock().unwrap().len());
        }
    }
}