fragile/
fragile.rs

1use std::cell::UnsafeCell;
2use std::cmp;
3use std::fmt;
4use std::mem;
5use std::mem::ManuallyDrop;
6use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT};
7
8use errors::InvalidThreadAccess;
9
/// Hands out a fresh, process-unique thread identifier.
///
/// Every call returns the current value of a global counter and advances it
/// by one, so no two calls observe the same number (until the counter wraps).
fn next_thread_id() -> usize {
    // An atomic is already safe to share between threads, so the counter can
    // be a plain (non-`mut`) static and no `unsafe` block is required.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    COUNTER.fetch_add(1, Ordering::SeqCst)
}
14
15pub(crate) fn get_thread_id() -> usize {
16    thread_local!(static THREAD_ID: usize = next_thread_id());
17    THREAD_ID.with(|&x| x)
18}
19
/// A `Fragile<T>` wraps a non-`Send` `T` so that it can be safely sent to
/// other threads.
///
/// Once the value has been wrapped, the wrapper can be sent to other threads,
/// but accessing the value on those threads will fail.
///
/// If the value needs destruction and the fragile wrapper is dropped on
/// another thread, the destructor will panic.  Alternatively you can use
/// `Sticky<T>`, which will not panic but might temporarily leak the value.
pub struct Fragile<T> {
    // The wrapped value.  `ManuallyDrop` leaves it to the `Drop` impl to
    // decide whether the value may be destroyed on the current thread, and
    // `UnsafeCell` lets the accessors reach the value through a raw pointer.
    value: ManuallyDrop<UnsafeCell<Box<T>>>,
    // Identifier of the thread that created the wrapper; only that thread is
    // allowed to access `value`.
    thread_id: usize,
}
32
33impl<T> Fragile<T> {
34    /// Creates a new `Fragile` wrapping a `value`.
35    ///
36    /// The value that is moved into the `Fragile` can be non `Send` and
37    /// will be anchored to the thread that created the object.  If the
38    /// fragile wrapper type ends up being send from thread to thread
39    /// only the original thread can interact with the value.
40    pub fn new(value: T) -> Self {
41        Fragile {
42            value: ManuallyDrop::new(UnsafeCell::new(Box::new(value))),
43            thread_id: get_thread_id(),
44        }
45    }
46
47    /// Returns `true` if the access is valid.
48    ///
49    /// This will be `false` if the value was sent to another thread.
50    pub fn is_valid(&self) -> bool {
51        get_thread_id() == self.thread_id
52    }
53
54    #[inline(always)]
55    fn assert_thread(&self) {
56        if !self.is_valid() {
57            panic!("trying to access wrapped value in fragile container from incorrect thread.");
58        }
59    }
60
61    /// Consumes the `Fragile`, returning the wrapped value.
62    ///
63    /// # Panics
64    ///
65    /// Panics if called from a different thread than the one where the
66    /// original value was created.
67    pub fn into_inner(mut self) -> T {
68        self.assert_thread();
69        unsafe {
70            let value = mem::replace(&mut self.value, mem::uninitialized());
71            mem::forget(self);
72            *ManuallyDrop::into_inner(value).into_inner()
73        }
74    }
75
76    /// Consumes the `Fragile`, returning the wrapped value if successful.
77    ///
78    /// The wrapped value is returned if this is called from the same thread
79    /// as the one where the original value was created, otherwise the
80    /// `Fragile` is returned as `Err(self)`.
81    pub fn try_into_inner(self) -> Result<T, Self> {
82        if get_thread_id() == self.thread_id {
83            Ok(self.into_inner())
84        } else {
85            Err(self)
86        }
87    }
88
89    /// Immutably borrows the wrapped value.
90    ///
91    /// # Panics
92    ///
93    /// Panics if the calling thread is not the one that wrapped the value.
94    /// For a non-panicking variant, use [`try_get`](#method.try_get`).
95    pub fn get(&self) -> &T {
96        self.assert_thread();
97        unsafe { &*self.value.get() }
98    }
99
100    /// Mutably borrows the wrapped value.
101    ///
102    /// # Panics
103    ///
104    /// Panics if the calling thread is not the one that wrapped the value.
105    /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
106    pub fn get_mut(&mut self) -> &mut T {
107        self.assert_thread();
108        unsafe { &mut *self.value.get() }
109    }
110
111    /// Tries to immutably borrow the wrapped value.
112    ///
113    /// Returns `None` if the calling thread is not the one that wrapped the value.
114    pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
115        if get_thread_id() == self.thread_id {
116            unsafe { Ok(&*self.value.get()) }
117        } else {
118            Err(InvalidThreadAccess)
119        }
120    }
121
122    /// Tries to mutably borrow the wrapped value.
123    ///
124    /// Returns `None` if the calling thread is not the one that wrapped the value.
125    pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
126        if get_thread_id() == self.thread_id {
127            unsafe { Ok(&mut *self.value.get()) }
128        } else {
129            Err(InvalidThreadAccess)
130        }
131    }
132}
133
134impl<T> Drop for Fragile<T> {
135    fn drop(&mut self) {
136        if mem::needs_drop::<T>() {
137            if get_thread_id() == self.thread_id {
138                unsafe { ManuallyDrop::drop(&mut self.value) }
139            } else {
140                panic!("destructor of fragile object ran on wrong thread");
141            }
142        }
143    }
144}
145
146impl<T> From<T> for Fragile<T> {
147    #[inline]
148    fn from(t: T) -> Fragile<T> {
149        Fragile::new(t)
150    }
151}
152
153impl<T: Clone> Clone for Fragile<T> {
154    #[inline]
155    fn clone(&self) -> Fragile<T> {
156        Fragile::new(self.get().clone())
157    }
158}
159
160impl<T: Default> Default for Fragile<T> {
161    #[inline]
162    fn default() -> Fragile<T> {
163        Fragile::new(T::default())
164    }
165}
166
167impl<T: PartialEq> PartialEq for Fragile<T> {
168    #[inline]
169    fn eq(&self, other: &Fragile<T>) -> bool {
170        *self.get() == *other.get()
171    }
172}
173
174impl<T: Eq> Eq for Fragile<T> {}
175
176impl<T: PartialOrd> PartialOrd for Fragile<T> {
177    #[inline]
178    fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
179        self.get().partial_cmp(&*other.get())
180    }
181
182    #[inline]
183    fn lt(&self, other: &Fragile<T>) -> bool {
184        *self.get() < *other.get()
185    }
186
187    #[inline]
188    fn le(&self, other: &Fragile<T>) -> bool {
189        *self.get() <= *other.get()
190    }
191
192    #[inline]
193    fn gt(&self, other: &Fragile<T>) -> bool {
194        *self.get() > *other.get()
195    }
196
197    #[inline]
198    fn ge(&self, other: &Fragile<T>) -> bool {
199        *self.get() >= *other.get()
200    }
201}
202
203impl<T: Ord> Ord for Fragile<T> {
204    #[inline]
205    fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
206        self.get().cmp(&*other.get())
207    }
208}
209
210impl<T: fmt::Display> fmt::Display for Fragile<T> {
211    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
212        fmt::Display::fmt(self.get(), f)
213    }
214}
215
216impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
217    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
218        match self.try_get() {
219            Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
220            Err(..) => {
221                struct InvalidPlaceholder;
222                impl fmt::Debug for InvalidPlaceholder {
223                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
224                        f.write_str("<invalid thread>")
225                    }
226                }
227
228                f.debug_struct("Fragile")
229                    .field("value", &InvalidPlaceholder)
230                    .finish()
231            }
232        }
233    }
234}
235
236// this type is sync because access can only ever happy from the same thread
237// that created it originally.  All other threads will be able to safely
238// call some basic operations on the reference and they will fail.
239unsafe impl<T> Sync for Fragile<T> {}
240
241// The entire point of this type is to be Send
242unsafe impl<T> Send for Fragile<T> {}
243
#[test]
fn test_basic() {
    use std::thread;

    let val = Fragile::new(true);
    assert_eq!(val.to_string(), "true");
    assert_eq!(val.get(), &true);
    assert!(val.try_get().is_ok());

    // Once moved to another thread, the value must be inaccessible.
    let handle = thread::spawn(move || {
        assert!(val.try_get().is_err());
    });
    handle.join().unwrap();
}
256
#[test]
fn test_mut() {
    let mut val = Fragile::new(true);
    {
        // Write through the mutable accessor, then read the change back.
        let slot = val.get_mut();
        *slot = false;
    }
    assert_eq!(val.to_string(), "false");
    assert_eq!(val.get(), &false);
}
264
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let val = Fragile::new(true);
    // `get` from a non-owning thread panics; `unwrap` on the join result
    // re-raises that panic in the test thread.
    let worker = thread::spawn(move || {
        val.get();
    });
    worker.join().unwrap();
}
275
#[test]
fn test_noop_drop_elsewhere() {
    use std::thread;

    let val = Fragile::new(true);
    let worker = thread::spawn(move || {
        // Using `val` forces the move; it is then dropped on this foreign
        // thread, which must not panic because `bool` has no destructor.
        let _ = val.try_get();
    });
    worker.join().unwrap();
}
286
#[test]
fn test_panic_on_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Sets the shared flag when its destructor runs.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let val = Fragile::new(X(Arc::clone(&was_called)));

    let outcome = thread::spawn(move || {
        val.try_get().ok();
    })
    .join();

    // Dropping the wrapper on the foreign thread must panic, and the inner
    // destructor must never have run.
    assert!(outcome.is_err());
    assert!(!was_called.load(Ordering::SeqCst));
}
308
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::thread;

    let val = Fragile::new(Rc::new(true));
    let (tx, rx) = channel();

    // The worker cannot read the `Rc`, but it may hand the wrapper back.
    let sender = thread::spawn(move || {
        assert!(val.try_get().is_err());
        tx.send(val).unwrap();
    });

    let received = rx.recv().unwrap();
    assert!(**received.get());

    sender.join().unwrap();
}