use std::cell::UnsafeCell;
use std::cmp;
use std::collections::HashMap;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::sync::atomic::{AtomicUsize, Ordering};

use errors::InvalidThreadAccess;

fn next_item_id() -> usize {
    // A plain (non-mut) static atomic needs no `unsafe` and replaces the
    // deprecated `ATOMIC_USIZE_INIT` initializer.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    COUNTER.fetch_add(1, Ordering::SeqCst)
}

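// Per-thread registry mapping an item id to the type-erased pointer of the
// stored value plus a drop function that knows how to free that value.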
struct Registry(HashMap<usize, (UnsafeCell<*mut ()>, Box<dyn Fn(&UnsafeCell<*mut ()>)>)>);

impl Drop for Registry {
    fn drop(&mut self) {
        // When the owning thread shuts down, drop every value that is still
        // registered, using the type-erased drop function stored next to it.
        for value in self.0.values() {
            (value.1)(&value.0);
        }
    }
}

thread_local!(static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(Default::default())));

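/// A `Sticky<T>` keeps a value of type `T` stored in the thread it was
/// created in, while the `Sticky` handle itself may be sent to and shared
/// with other threads.
///
/// The wrapped value can only be accessed from the originating thread: the
/// `get`/`get_mut` methods panic and the `try_*` methods return an error
/// when called from any other thread. If the handle is dropped on a foreign
/// thread, the value is not dropped there; it is cleaned up when the
/// originating thread's registry is torn down.
///
/// A minimal usage sketch (assuming `Sticky` is in scope):
///
/// ```ignore
/// let val = Sticky::new(true);
/// assert_eq!(*val.get(), true);
/// ```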
pub struct Sticky<T> {
    item_id: usize,
    _marker: PhantomData<*mut T>,
}

impl<T> Drop for Sticky<T> {
    fn drop(&mut self) {
        if mem::needs_drop::<T>() {
            unsafe {
                if self.is_valid() {
                    self.unsafe_take_value();
                }
            }
        }
    }
}

impl<T> Sticky<T> {
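    /// Creates a new `Sticky` wrapping a value.
    ///
    /// The value is moved into the current thread's registry, so it can only
    /// be accessed (and normally dropped) from this thread.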
    pub fn new(value: T) -> Self {
        let item_id = next_item_id();
        REGISTRY.with(|registry| unsafe {
            (*registry.get()).0.insert(
                item_id,
                (
                    UnsafeCell::new(Box::into_raw(Box::new(value)) as *mut _),
                    Box::new(|cell| {
                        // Reconstruct the `Box<T>` from the type-erased pointer so
                        // the value is dropped properly at registry teardown.
                        let b: Box<T> = Box::from_raw(*(cell.get() as *mut *mut T));
                        mem::drop(b);
                    }),
                ),
            );
        });
        Sticky {
            item_id,
            _marker: PhantomData,
        }
    }

    #[inline(always)]
    fn with_value<F: FnOnce(&UnsafeCell<Box<T>>) -> R, R>(&self, f: F) -> R {
        REGISTRY.with(|registry| unsafe {
            let reg = &(*registry.get()).0;
            if let Some(item) = reg.get(&self.item_id) {
                // `Box<T>` and the erased `*mut ()` share the same representation,
                // so the cell can be reinterpreted in place.
                f(mem::transmute(&item.0))
            } else {
                panic!("trying to access wrapped value in sticky container from incorrect thread.");
            }
        })
    }

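    /// Returns `true` if access to the wrapped value is valid from the
    /// current thread, i.e. the calling thread is the one the value was
    /// created in and its registry still holds the value.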
    #[inline(always)]
    pub fn is_valid(&self) -> bool {
        unsafe {
            REGISTRY
                .try_with(|registry| (*registry.get()).0.contains_key(&self.item_id))
                .unwrap_or(false)
        }
    }

    #[inline(always)]
    fn assert_thread(&self) {
        if !self.is_valid() {
            panic!("trying to access wrapped value in sticky container from incorrect thread.");
        }
    }

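    /// Consumes the `Sticky`, returning the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one the value was
    /// created in.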
    pub fn into_inner(mut self) -> T {
        self.assert_thread();
        unsafe {
            let rv = self.unsafe_take_value();
            // The value has already been removed from the registry, so skip
            // the `Drop` impl which would otherwise try to take it again.
            mem::forget(self);
            rv
        }
    }

    unsafe fn unsafe_take_value(&mut self) -> T {
        let ptr = REGISTRY
            .with(|registry| (*registry.get()).0.remove(&self.item_id))
            .unwrap()
            .0
            .into_inner();
        let rv = Box::from_raw(ptr as *mut T);
        *rv
    }

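    /// Consumes the `Sticky`, returning the wrapped value if the calling
    /// thread is the one the value was created in; otherwise the `Sticky`
    /// itself is returned as the error.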
    pub fn try_into_inner(self) -> Result<T, Self> {
        if self.is_valid() {
            Ok(self.into_inner())
        } else {
            Err(self)
        }
    }

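    /// Immutably borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one the value was
    /// created in.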
    pub fn get(&self) -> &T {
        self.with_value(|value| unsafe { &*value.get() })
    }

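    /// Mutably borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one the value was
    /// created in.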
    pub fn get_mut(&mut self) -> &mut T {
        self.with_value(|value| unsafe { &mut *value.get() })
    }

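    /// Tries to immutably borrow the wrapped value, returning
    /// `Err(InvalidThreadAccess)` when called from the wrong thread.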
    pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
        if self.is_valid() {
            unsafe { Ok(self.with_value(|value| &*value.get())) }
        } else {
            Err(InvalidThreadAccess)
        }
    }

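    /// Tries to mutably borrow the wrapped value, returning
    /// `Err(InvalidThreadAccess)` when called from the wrong thread.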
    pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
        if self.is_valid() {
            unsafe { Ok(self.with_value(|value| &mut *value.get())) }
        } else {
            Err(InvalidThreadAccess)
        }
    }
}

impl<T> From<T> for Sticky<T> {
    #[inline]
    fn from(t: T) -> Sticky<T> {
        Sticky::new(t)
    }
}

impl<T: Clone> Clone for Sticky<T> {
    #[inline]
    fn clone(&self) -> Sticky<T> {
        Sticky::new(self.get().clone())
    }
}

impl<T: Default> Default for Sticky<T> {
    #[inline]
    fn default() -> Sticky<T> {
        Sticky::new(T::default())
    }
}

impl<T: PartialEq> PartialEq for Sticky<T> {
    #[inline]
    fn eq(&self, other: &Sticky<T>) -> bool {
        *self.get() == *other.get()
    }
}

impl<T: Eq> Eq for Sticky<T> {}

impl<T: PartialOrd> PartialOrd for Sticky<T> {
    #[inline]
    fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
        self.get().partial_cmp(&*other.get())
    }

    #[inline]
    fn lt(&self, other: &Sticky<T>) -> bool {
        *self.get() < *other.get()
    }

    #[inline]
    fn le(&self, other: &Sticky<T>) -> bool {
        *self.get() <= *other.get()
    }

    #[inline]
    fn gt(&self, other: &Sticky<T>) -> bool {
        *self.get() > *other.get()
    }

    #[inline]
    fn ge(&self, other: &Sticky<T>) -> bool {
        *self.get() >= *other.get()
    }
}

impl<T: Ord> Ord for Sticky<T> {
    #[inline]
    fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
        self.get().cmp(&*other.get())
    }
}

impl<T: fmt::Display> fmt::Display for Sticky<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        fmt::Display::fmt(self.get(), f)
    }
}

impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        match self.try_get() {
            Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
            Err(..) => {
                struct InvalidPlaceholder;
                impl fmt::Debug for InvalidPlaceholder {
                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                        f.write_str("<invalid thread>")
                    }
                }

                f.debug_struct("Sticky")
                    .field("value", &InvalidPlaceholder)
                    .finish()
            }
        }
    }
}

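// Safety: the `Sticky` handle only stores an id; the value it refers to never
// leaves the thread it was created in, and every access path checks the
// current thread's registry before handing out a reference, so sharing or
// sending the handle across threads cannot produce cross-thread access to `T`.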
unsafe impl<T> Sync for Sticky<T> {}

unsafe impl<T> Send for Sticky<T> {}

#[test]
fn test_basic() {
    use std::thread;
    let val = Sticky::new(true);
    assert_eq!(val.to_string(), "true");
    assert_eq!(val.get(), &true);
    assert!(val.try_get().is_ok());
    thread::spawn(move || {
        assert!(val.try_get().is_err());
    })
    .join()
    .unwrap();
}

#[test]
fn test_mut() {
    let mut val = Sticky::new(true);
    *val.get_mut() = false;
    assert_eq!(val.to_string(), "false");
    assert_eq!(val.get(), &false);
}

#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;
    let val = Sticky::new(true);
    thread::spawn(move || {
        val.get();
    })
    .join()
    .unwrap();
}

#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    let was_called = Arc::new(AtomicBool::new(false));
    struct X(Arc<AtomicBool>);
    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }
    let val = Sticky::new(X(was_called.clone()));
    mem::drop(val);
    assert_eq!(was_called.load(Ordering::SeqCst), true);
}

#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let was_called = Arc::new(AtomicBool::new(false));

    {
        let was_called = was_called.clone();
        thread::spawn(move || {
            struct X(Arc<AtomicBool>);
            impl Drop for X {
                fn drop(&mut self) {
                    self.0.store(true, Ordering::SeqCst);
                }
            }

            let val = Sticky::new(X(was_called.clone()));
            assert!(
                thread::spawn(move || {
                    val.try_get().ok();
                })
                .join()
                .is_ok()
            );

            assert_eq!(was_called.load(Ordering::SeqCst), false);
        })
        .join()
        .unwrap();
    }

    assert_eq!(was_called.load(Ordering::SeqCst), true);
}

#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;
    let val = Sticky::new(Rc::new(true));
    thread::spawn(move || {
        assert!(val.try_get().is_err());
    })
    .join()
    .unwrap();
}