// neovim_lib/rpc/client.rs

1use std::error::Error;
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::sync::{mpsc, Arc, Mutex};
4use std::thread;
5use std::thread::JoinHandle;
6use std::time::{Duration, Instant};
7
8use super::handler::{self, DefaultHandler, Handler, RequestHandler};
9use rmpv::Value;
10
11use super::model;
12
13type Callback = Box<FnMut(Result<Value, Value>) + Send + 'static>;
14type Queue = Arc<Mutex<Vec<(u64, Sender)>>>;
15
16enum Sender {
17    Sync(mpsc::Sender<Result<Value, Value>>),
18    Async(Callback),
19}
20
21impl Sender {
22    fn send(self, res: Result<Value, Value>) {
23        match self {
24            Sender::Sync(sender) => sender.send(res).unwrap(),
25            Sender::Async(mut cb) => cb(res),
26        };
27    }
28}
29
30pub struct Client<R, W>
31where
32    R: Read + Send + 'static,
33    W: Write + Send + 'static,
34{
35    reader: Option<BufReader<R>>,
36    writer: Arc<Mutex<BufWriter<W>>>,
37    dispatch_guard: Option<JoinHandle<()>>,
38    event_loop_started: bool,
39    queue: Queue,
40    msgid_counter: u64,
41}
42
43impl<R, W> Client<R, W>
44where
45    R: Read + Send + 'static,
46    W: Write + Send + 'static,
47{
48    pub fn take_dispatch_guard(&mut self) -> JoinHandle<()> {
49        self.dispatch_guard
50            .take()
51            .expect("Can only take join handle after running event loop")
52    }
53
54    pub fn start_event_loop_channel_handler<H>(
55        &mut self,
56        request_handler: H,
57    ) -> mpsc::Receiver<(String, Vec<Value>)>
58    where
59        H: RequestHandler + Send + 'static,
60    {
61        let (handler, reciever) = handler::channel(request_handler);
62
63        self.dispatch_guard = Some(Self::dispatch_thread(
64            self.queue.clone(),
65            self.reader.take().unwrap(),
66            self.writer.clone(),
67            handler,
68        ));
69        self.event_loop_started = true;
70
71        reciever
72    }
73
74    pub fn start_event_loop_handler<H>(&mut self, handler: H)
75    where
76        H: Handler + Send + 'static,
77    {
78        self.dispatch_guard = Some(Self::dispatch_thread(
79            self.queue.clone(),
80            self.reader.take().unwrap(),
81            self.writer.clone(),
82            handler,
83        ));
84        self.event_loop_started = true;
85    }
86
87    pub fn start_event_loop(&mut self) {
88        self.dispatch_guard = Some(Self::dispatch_thread(
89            self.queue.clone(),
90            self.reader.take().unwrap(),
91            self.writer.clone(),
92            DefaultHandler(),
93        ));
94        self.event_loop_started = true;
95    }
96
97    pub fn new(reader: R, writer: W) -> Self {
98        let queue = Arc::new(Mutex::new(Vec::new()));
99        Client {
100            reader: Some(BufReader::new(reader)),
101            writer: Arc::new(Mutex::new(BufWriter::new(writer))),
102            msgid_counter: 0,
103            queue: queue.clone(),
104            dispatch_guard: None,
105            event_loop_started: false,
106        }
107    }
108
109    pub fn call_async(&mut self, method: String, args: Vec<Value>, cb: Option<Callback>) {
110        if !self.event_loop_started {
111            if let Some(mut cb) = cb {
112                cb(Err(Value::from("Event loop not started")));
113            } else {
114                error!("Event loop not started");
115            }
116            return;
117        }
118
119        self.send_msg_async(method, args, cb);
120    }
121
122    pub fn call_timeout(
123        &mut self,
124        method: &str,
125        args: Vec<Value>,
126        dur: Duration,
127    ) -> Result<Value, Value> {
128        if !self.event_loop_started {
129            return Err(Value::from("Event loop not started"));
130        }
131
132        let instant = Instant::now();
133        let delay = Duration::from_millis(1);
134
135        let receiver = self.send_msg(method, args);
136
137        loop {
138            match receiver.try_recv() {
139                Err(mpsc::TryRecvError::Empty) => {
140                    thread::sleep(delay);
141                    if instant.elapsed() >= dur {
142                        return Err(Value::from(format!("Wait timeout ({})", method)));
143                    }
144                }
145                Err(mpsc::TryRecvError::Disconnected) => {
146                    return Err(Value::from(format!("Channel disconnected ({})", method)))
147                }
148                Ok(val) => return val,
149            };
150        }
151    }
152
153    fn send_msg_async(&mut self, method: String, params: Vec<Value>, cb: Option<Callback>) {
154        let msgid = self.msgid_counter;
155        self.msgid_counter += 1;
156
157        let req = model::RpcMessage::RpcRequest {
158            msgid,
159            method,
160            params,
161        };
162
163        if let Some(cb) = cb {
164            self.queue.lock().unwrap().push((msgid, Sender::Async(cb)));
165        }
166
167        let writer = &mut *self.writer.lock().unwrap();
168        model::encode(writer, req).expect("Error sending message");
169    }
170
171    fn send_msg(&mut self, method: &str, args: Vec<Value>) -> mpsc::Receiver<Result<Value, Value>> {
172        let msgid = self.msgid_counter;
173        self.msgid_counter += 1;
174
175        let req = model::RpcMessage::RpcRequest {
176            msgid,
177            method: method.to_owned(),
178            params: args,
179        };
180
181        let (sender, receiver) = mpsc::channel();
182        self.queue
183            .lock()
184            .unwrap()
185            .push((msgid, Sender::Sync(sender)));
186
187        let writer = &mut *self.writer.lock().unwrap();
188        model::encode(writer, req).expect("Error sending message");
189
190        receiver
191    }
192
193    pub fn call(
194        &mut self,
195        method: &str,
196        args: Vec<Value>,
197        dur: Option<Duration>,
198    ) -> Result<Value, Value> {
199        match dur {
200            Some(dur) => self.call_timeout(method, args, dur),
201            None => self.call_inf(method, args),
202        }
203    }
204
205    pub fn call_inf(&mut self, method: &str, args: Vec<Value>) -> Result<Value, Value> {
206        if !self.event_loop_started {
207            return Err(Value::from("Event loop not started"));
208        }
209
210        let receiver = self.send_msg(method, args);
211
212        receiver.recv().unwrap()
213    }
214
215    fn send_error_to_callers(queue: &Queue, err: &Box<Error>) {
216        let mut queue = queue.lock().unwrap();
217        queue.drain(0..).for_each(|sender| {
218            sender
219                .1
220                .send(Err(Value::from(format!("Error read response: {}", err))))
221        });
222    }
223
224    fn dispatch_thread<H>(
225        queue: Queue,
226        mut reader: BufReader<R>,
227        writer: Arc<Mutex<BufWriter<W>>>,
228        mut handler: H,
229    ) -> JoinHandle<()>
230    where
231        H: Handler + Send + 'static,
232    {
233        thread::spawn(move || loop {
234            let msg = match model::decode(&mut reader) {
235                Ok(msg) => msg,
236                Err(e) => {
237                    error!("Error while reading: {}", e);
238                    Self::send_error_to_callers(&queue, &e);
239                    return;
240                }
241            };
242            debug!("Get message {:?}", msg);
243            match msg {
244                model::RpcMessage::RpcRequest {
245                    msgid,
246                    method,
247                    params,
248                } => {
249                    let response = match handler.handle_request(&method, params) {
250                        Ok(result) => model::RpcMessage::RpcResponse {
251                            msgid,
252                            result,
253                            error: Value::Nil,
254                        },
255                        Err(error) => model::RpcMessage::RpcResponse {
256                            msgid,
257                            result: Value::Nil,
258                            error,
259                        },
260                    };
261
262                    let writer = &mut *writer.lock().unwrap();
263                    model::encode(writer, response).expect("Error sending RPC response");
264                }
265                model::RpcMessage::RpcResponse {
266                    msgid,
267                    result,
268                    error,
269                } => {
270                    let sender = find_sender(&queue, msgid);
271                    if error != Value::Nil {
272                        sender.send(Err(error));
273                    } else {
274                        sender.send(Ok(result));
275                    }
276                }
277                model::RpcMessage::RpcNotification { method, params } => {
278                    handler.handle_notify(&method, params);
279                }
280            };
281        })
282    }
283}
284
285/* The idea to use Vec here instead of HashMap
286 * is that Vec is faster on small queue sizes
287 * in most cases Vec.len = 1 so we just take first item in iteration.
288 */
289fn find_sender(queue: &Queue, msgid: u64) -> Sender {
290    let mut queue = queue.lock().unwrap();
291
292    let pos = queue.iter().position(|req| req.0 == msgid).unwrap();
293    queue.remove(pos).1
294}
295
296#[cfg(test)]
297mod tests {
298    use super::*;
299
300    #[test]
301    fn test_find_sender() {
302        let queue = Arc::new(Mutex::new(Vec::new()));
303
304        {
305            let (sender, _receiver) = mpsc::channel();
306            queue.lock().unwrap().push((1, Sender::Sync(sender)));
307        }
308        {
309            let (sender, _receiver) = mpsc::channel();
310            queue.lock().unwrap().push((2, Sender::Sync(sender)));
311        }
312        {
313            let (sender, _receiver) = mpsc::channel();
314            queue.lock().unwrap().push((3, Sender::Sync(sender)));
315        }
316
317        find_sender(&queue, 1);
318        assert_eq!(2, queue.lock().unwrap().len());
319        find_sender(&queue, 2);
320        assert_eq!(1, queue.lock().unwrap().len());
321        find_sender(&queue, 3);
322        assert!(queue.lock().unwrap().is_empty());
323    }
324}