fragile/
sticky.rs

1use std::cell::UnsafeCell;
2use std::cmp;
3use std::collections::HashMap;
4use std::fmt;
5use std::marker::PhantomData;
6use std::mem;
7use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT};
8
9use errors::InvalidThreadAccess;
10
/// Returns a fresh id for a `Sticky` value (unique until the counter wraps).
///
/// `AtomicUsize` already provides interior mutability, so the counter can be
/// a plain `static`. Declaring it `static mut` only forced the access through
/// `unsafe` (and trips the `static_mut_refs` lint on current compilers)
/// without buying anything.
// NOTE(review): `fetch_add` wraps on overflow, so after `usize::MAX` calls ids
// would start repeating. Presumably unreachable on 64-bit targets; confirm
// this is acceptable on 32-bit targets.
fn next_item_id() -> usize {
    static COUNTER: AtomicUsize = ATOMIC_USIZE_INIT;
    COUNTER.fetch_add(1, Ordering::SeqCst)
}
15
16struct Registry(HashMap<usize, (UnsafeCell<*mut ()>, Box<Fn(&UnsafeCell<*mut ()>)>)>);
17
18impl Drop for Registry {
19    fn drop(&mut self) {
20        for (_, value) in self.0.iter() {
21            (value.1)(&value.0);
22        }
23    }
24}
25
26thread_local!(static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(Default::default())));
27
/// A `Sticky<T>` keeps a value `T` stored in a thread.
///
/// This type works similarly to `Fragile<T>` and exposes the same
/// interface. The difference is how destruction is handled: whereas
/// `Fragile<T>` has its destructor called in the thread where the value was
/// sent, a `Sticky<T>` that is moved to another thread has the wrapped
/// value's destructor called when the originating thread tears down.
///
/// As this uses TLS internally, the general rules about the platform
/// limitations of destructors for TLS apply.
pub struct Sticky<T> {
    /// Ties the handle to `T`. The raw-pointer form makes the type neither
    /// `Send` nor `Sync` by default; both are opted into explicitly with
    /// `unsafe impl`s for `Sticky`.
    _marker: PhantomData<*mut T>,
    /// Key of the wrapped value in the originating thread's `REGISTRY`.
    item_id: usize,
}
42
43impl<T> Drop for Sticky<T> {
44    fn drop(&mut self) {
45        if mem::needs_drop::<T>() {
46            unsafe {
47                if self.is_valid() {
48                    self.unsafe_take_value();
49                }
50            }
51        }
52    }
53}
54
55impl<T> Sticky<T> {
56    /// Creates a new `Sticky` wrapping a `value`.
57    ///
58    /// The value that is moved into the `Sticky` can be non `Send` and
59    /// will be anchored to the thread that created the object.  If the
60    /// sticky wrapper type ends up being send from thread to thread
61    /// only the original thread can interact with the value.
62    pub fn new(value: T) -> Self {
63        let item_id = next_item_id();
64        REGISTRY.with(|registry| unsafe {
65            (*registry.get()).0.insert(
66                item_id,
67                (
68                    UnsafeCell::new(Box::into_raw(Box::new(value)) as *mut _),
69                    Box::new(|cell| {
70                        let b: Box<T> = Box::from_raw(*(cell.get() as *mut *mut T));
71                        mem::drop(b);
72                    }),
73                ),
74            );
75        });
76        Sticky {
77            item_id: item_id,
78            _marker: PhantomData,
79        }
80    }
81
82    #[inline(always)]
83    fn with_value<F: FnOnce(&UnsafeCell<Box<T>>) -> R, R>(&self, f: F) -> R {
84        REGISTRY.with(|registry| unsafe {
85            let reg = &(*(*registry).get()).0;
86            if let Some(item) = reg.get(&self.item_id) {
87                f(mem::transmute(&item.0))
88            } else {
89                panic!("trying to access wrapped value in sticky container from incorrect thread.");
90            }
91        })
92    }
93
94    /// Returns `true` if the access is valid.
95    ///
96    /// This will be `false` if the value was sent to another thread.
97    #[inline(always)]
98    pub fn is_valid(&self) -> bool {
99        // We use `try-with` here to avoid crashing if the TLS is already tearing down.
100        unsafe { REGISTRY.try_with(|registry| (*registry.get()).0.contains_key(&self.item_id)).unwrap_or(false) }
101    }
102
103    #[inline(always)]
104    fn assert_thread(&self) {
105        if !self.is_valid() {
106            panic!("trying to access wrapped value in sticky container from incorrect thread.");
107        }
108    }
109
110    /// Consumes the `Sticky`, returning the wrapped value.
111    ///
112    /// # Panics
113    ///
114    /// Panics if called from a different thread than the one where the
115    /// original value was created.
116    pub fn into_inner(mut self) -> T {
117        self.assert_thread();
118        unsafe {
119            let rv = self.unsafe_take_value();
120            mem::forget(self);
121            rv
122        }
123    }
124
125    unsafe fn unsafe_take_value(&mut self) -> T {
126        let ptr = REGISTRY
127            .with(|registry| (*registry.get()).0.remove(&self.item_id))
128            .unwrap()
129            .0
130            .into_inner();
131        let rv = Box::from_raw(ptr as *mut T);
132        *rv
133    }
134
135    /// Consumes the `Sticky`, returning the wrapped value if successful.
136    ///
137    /// The wrapped value is returned if this is called from the same thread
138    /// as the one where the original value was created, otherwise the
139    /// `Sticky` is returned as `Err(self)`.
140    pub fn try_into_inner(self) -> Result<T, Self> {
141        if self.is_valid() {
142            Ok(self.into_inner())
143        } else {
144            Err(self)
145        }
146    }
147
148    /// Immutably borrows the wrapped value.
149    ///
150    /// # Panics
151    ///
152    /// Panics if the calling thread is not the one that wrapped the value.
153    /// For a non-panicking variant, use [`try_get`](#method.try_get`).
154    pub fn get(&self) -> &T {
155        self.with_value(|value| unsafe { &*value.get() })
156    }
157
158    /// Mutably borrows the wrapped value.
159    ///
160    /// # Panics
161    ///
162    /// Panics if the calling thread is not the one that wrapped the value.
163    /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
164    pub fn get_mut(&mut self) -> &mut T {
165        self.with_value(|value| unsafe { &mut *value.get() })
166    }
167
168    /// Tries to immutably borrow the wrapped value.
169    ///
170    /// Returns `None` if the calling thread is not the one that wrapped the value.
171    pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
172        if self.is_valid() {
173            unsafe { Ok(self.with_value(|value| &*value.get())) }
174        } else {
175            Err(InvalidThreadAccess)
176        }
177    }
178
179    /// Tries to mutably borrow the wrapped value.
180    ///
181    /// Returns `None` if the calling thread is not the one that wrapped the value.
182    pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
183        if self.is_valid() {
184            unsafe { Ok(self.with_value(|value| &mut *value.get())) }
185        } else {
186            Err(InvalidThreadAccess)
187        }
188    }
189}
190
191impl<T> From<T> for Sticky<T> {
192    #[inline]
193    fn from(t: T) -> Sticky<T> {
194        Sticky::new(t)
195    }
196}
197
198impl<T: Clone> Clone for Sticky<T> {
199    #[inline]
200    fn clone(&self) -> Sticky<T> {
201        Sticky::new(self.get().clone())
202    }
203}
204
205impl<T: Default> Default for Sticky<T> {
206    #[inline]
207    fn default() -> Sticky<T> {
208        Sticky::new(T::default())
209    }
210}
211
212impl<T: PartialEq> PartialEq for Sticky<T> {
213    #[inline]
214    fn eq(&self, other: &Sticky<T>) -> bool {
215        *self.get() == *other.get()
216    }
217}
218
219impl<T: Eq> Eq for Sticky<T> {}
220
221impl<T: PartialOrd> PartialOrd for Sticky<T> {
222    #[inline]
223    fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
224        self.get().partial_cmp(&*other.get())
225    }
226
227    #[inline]
228    fn lt(&self, other: &Sticky<T>) -> bool {
229        *self.get() < *other.get()
230    }
231
232    #[inline]
233    fn le(&self, other: &Sticky<T>) -> bool {
234        *self.get() <= *other.get()
235    }
236
237    #[inline]
238    fn gt(&self, other: &Sticky<T>) -> bool {
239        *self.get() > *other.get()
240    }
241
242    #[inline]
243    fn ge(&self, other: &Sticky<T>) -> bool {
244        *self.get() >= *other.get()
245    }
246}
247
248impl<T: Ord> Ord for Sticky<T> {
249    #[inline]
250    fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
251        self.get().cmp(&*other.get())
252    }
253}
254
255impl<T: fmt::Display> fmt::Display for Sticky<T> {
256    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
257        fmt::Display::fmt(self.get(), f)
258    }
259}
260
261impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
262    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
263        match self.try_get() {
264            Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
265            Err(..) => {
266                struct InvalidPlaceholder;
267                impl fmt::Debug for InvalidPlaceholder {
268                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
269                        f.write_str("<invalid thread>")
270                    }
271                }
272
273                f.debug_struct("Sticky")
274                    .field("value", &InvalidPlaceholder)
275                    .finish()
276            }
277        }
278    }
279}
280
281// similar as for fragile ths type is sync because it only accesses TLS data
282// which is thread local.  There is nothing that needs to be synchronized.
283unsafe impl<T> Sync for Sticky<T> {}
284
285// The entire point of this type is to be Send
286unsafe impl<T> Send for Sticky<T> {}
287
#[test]
fn test_basic() {
    use std::thread;

    let sticky = Sticky::new(true);
    assert_eq!(sticky.to_string(), "true");
    assert_eq!(sticky.get(), &true);
    assert!(sticky.try_get().is_ok());

    let handle = thread::spawn(move || {
        // Off the owning thread the value must be unreachable.
        assert!(sticky.try_get().is_err());
    });
    handle.join().unwrap();
}
300
#[test]
fn test_mut() {
    let mut flag = Sticky::new(true);
    {
        let slot = flag.get_mut();
        *slot = false;
    }
    assert_eq!(flag.to_string(), "false");
    assert_eq!(flag.get(), &false);
}
308
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let sticky = Sticky::new(true);
    let worker = thread::spawn(move || {
        // `get` panics off the owning thread; `join` + `unwrap` re-raises it.
        sticky.get();
    });
    worker.join().unwrap();
}
319
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Flips the shared flag when dropped.
    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let sticky = Sticky::new(SetOnDrop(Arc::clone(&dropped)));
    mem::drop(sticky);
    assert!(dropped.load(Ordering::SeqCst));
}
335
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Flips the shared flag when dropped.
    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let observer = Arc::clone(&dropped);

    let owner = thread::spawn(move || {
        let sticky = Sticky::new(SetOnDrop(Arc::clone(&observer)));

        // Move the handle into a second thread and let it go out of scope
        // there; the wrapped value must not be destroyed on that thread.
        let visitor = thread::spawn(move || {
            sticky.try_get().ok();
        });
        assert!(visitor.join().is_ok());

        assert!(!observer.load(Ordering::SeqCst));
    });
    owner.join().unwrap();

    // Once the owning thread has exited, its registry has dropped the value.
    assert!(dropped.load(Ordering::SeqCst));
}
370
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // A non-`Send` payload can still ride along inside the `Sticky` handle.
    let shared = Sticky::new(Rc::new(true));
    let handle = thread::spawn(move || {
        assert!(shared.try_get().is_err());
    });
    handle.join().unwrap();
}