#![warn(missing_docs)]
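//! Per-object thread-local storage.
//!
//! This crate provides the `ThreadLocal` type, which gives each thread its
//! own lazily-created copy of a per-object value, unlike `std`'s
//! `thread_local!` macro, which only supports static thread-local storage.
//!
//! Per-thread values are not destroyed when their thread exits; they are only
//! destroyed when the `ThreadLocal` itself is dropped. Note that `thread_id`
//! (the module below) may hand a recycled id to a new thread, in which case
//! that thread can observe a value left behind by an exited thread.
//!
//! `CachedThreadLocal`, defined further down, adds a fast path for the first
//! thread that uses the value.
//!
//! A minimal usage sketch (this assumes the crate is built under the name
//! `thread_local`; adjust the `use` path if it differs):
//!
//! ```
//! use thread_local::ThreadLocal;
//!
//! let tls: ThreadLocal<u32> = ThreadLocal::new();
//! assert_eq!(tls.get(), None);
//! assert_eq!(*tls.get_or(|| Box::new(5)), 5);
//! assert_eq!(tls.get(), Some(&5));
//! ```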

#[macro_use]
extern crate lazy_static;

mod thread_id;
mod unreachable;

use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use std::sync::Mutex;
use std::marker::PhantomData;
use std::cell::UnsafeCell;
use std::fmt;
use std::iter::Chain;
use std::option::IntoIter as OptionIter;
use std::panic::UnwindSafe;
use unreachable::{UncheckedOptionExt, UncheckedResultExt};

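/// Thread-local variable wrapper. Each thread gets its own, lazily created,
/// copy of the contained value; `T` only needs to be `Send`, and the wrapper
/// itself is always `Sync`.
///
/// A minimal sketch of the per-thread behavior (assuming the crate name
/// `thread_local`):
///
/// ```
/// use std::sync::Arc;
/// use std::thread;
/// use thread_local::ThreadLocal;
///
/// let tls = Arc::new(ThreadLocal::new());
/// let tls2 = tls.clone();
/// thread::spawn(move || {
///     // The spawned thread starts with no value of its own.
///     assert_eq!(tls2.get(), None);
///     tls2.get_or(|| Box::new(1));
/// }).join().unwrap();
///
/// // The main thread never initialized its copy, so it still sees None.
/// assert_eq!(tls.get(), None);
/// ```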
pub struct ThreadLocal<T: ?Sized + Send> {
    // Pointer to the current top-level hash table.
    table: AtomicPtr<Table<T>>,

    // Mutex used when inserting into the table. The `usize` it guards counts
    // how many entries are in use across all tables.
    lock: Mutex<usize>,

    // PhantomData to indicate that we logically own `T` values.
    marker: PhantomData<T>,
}

struct Table<T: ?Sized + Send> {
    // Hash entries for this table.
    entries: Box<[TableEntry<T>]>,

    // Number of bits used by the hash function (the table holds
    // `1 << hash_bits` entries).
    hash_bits: usize,

    // The previous, half-sized table, kept alive so outstanding references
    // into it remain valid.
    prev: Option<Box<Table<T>>>,
}

struct TableEntry<T: ?Sized + Send> {
    // Thread id of this entry's owner, or 0 if the entry is empty.
    owner: AtomicUsize,

    // The value for the owning thread. Only that thread dereferences this
    // cell through `&self`; the iterators require `&mut self` or ownership.
    data: UnsafeCell<Option<Box<T>>>,
}

// `ThreadLocal` is always `Sync`, even if `T` isn't: each thread only ever
// accesses the entry it owns.
unsafe impl<T: ?Sized + Send> Sync for ThreadLocal<T> {}

impl<T: ?Sized + Send> Default for ThreadLocal<T> {
    fn default() -> ThreadLocal<T> {
        ThreadLocal::new()
    }
}

impl<T: ?Sized + Send> Drop for ThreadLocal<T> {
    fn drop(&mut self) {
        unsafe {
            // Reconstruct the Box so the table chain is freed: dropping the
            // top table recursively drops all previous tables and their data.
            Box::from_raw(self.table.load(Ordering::Relaxed));
        }
    }
}

// "Clone" by creating a fresh, empty entry. This is only used by
// `vec![entry; n]` to initialize new tables, so it never needs to copy
// existing contents.
impl<T: ?Sized + Send> Clone for TableEntry<T> {
    fn clone(&self) -> TableEntry<T> {
        TableEntry {
            owner: AtomicUsize::new(0),
            data: UnsafeCell::new(None),
        }
    }
}

// Hash a thread id into a table index using Fibonacci hashing: multiply by
// 2^N / phi (for an N-bit usize) and keep the top `bits` bits.
#[cfg(target_pointer_width = "32")]
#[inline]
fn hash(id: usize, bits: usize) -> usize {
    id.wrapping_mul(0x9E37_79B9) >> (32 - bits)
}
#[cfg(target_pointer_width = "64")]
#[inline]
fn hash(id: usize, bits: usize) -> usize {
    id.wrapping_mul(0x9E37_79B9_7F4A_7C15) >> (64 - bits)
}

impl<T: ?Sized + Send> ThreadLocal<T> {
    /// Creates a new empty `ThreadLocal`.
    pub fn new() -> ThreadLocal<T> {
        let entry = TableEntry {
            owner: AtomicUsize::new(0),
            data: UnsafeCell::new(None),
        };
        let table = Table {
            entries: vec![entry; 2].into_boxed_slice(),
            hash_bits: 1,
            prev: None,
        };
        ThreadLocal {
            table: AtomicPtr::new(Box::into_raw(Box::new(table))),
            lock: Mutex::new(0),
            marker: PhantomData,
        }
    }

    /// Returns the element for the current thread, if it exists.
    pub fn get(&self) -> Option<&T> {
        let id = thread_id::get();
        self.get_fast(id)
    }

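    /// Returns the element for the current thread, creating it with
    /// `create` if it doesn't yet exist.
    ///
    /// A minimal usage sketch (assuming the crate name `thread_local`):
    ///
    /// ```
    /// use thread_local::ThreadLocal;
    ///
    /// let tls: ThreadLocal<String> = ThreadLocal::new();
    /// let s = tls.get_or(|| Box::new(String::from("hello")));
    /// assert_eq!(s, "hello");
    /// ```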
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> Box<T>,
    {
        unsafe {
            self.get_or_try(|| Ok::<Box<T>, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

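    /// Returns the element for the current thread, or creates it with
    /// `create` if it doesn't yet exist. Unlike `get_or`, the creation
    /// function is fallible, and its error is passed through.
    ///
    /// A minimal usage sketch (assuming the crate name `thread_local`):
    ///
    /// ```
    /// use thread_local::ThreadLocal;
    ///
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// let result: Result<&u32, ()> = tls.get_or_try(|| Ok(Box::new(7)));
    /// assert_eq!(result, Ok(&7));
    /// ```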
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<Box<T>, E>,
    {
        let id = thread_id::get();
        match self.get_fast(id) {
            Some(x) => Ok(x),
            None => Ok(self.insert(id, create()?, true)),
        }
    }

    // Linear-probe search for the entry owned by `id`. Probing can stop at
    // the first empty slot because entries are never removed from a table.
    fn lookup(id: usize, table: &Table<T>) -> Option<&UnsafeCell<Option<Box<T>>>> {
        for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) {
            let owner = entry.owner.load(Ordering::Relaxed);
            if owner == id {
                return Some(&entry.data);
            }
            if owner == 0 {
                return None;
            }
        }
        unreachable!();
    }

    // Fast path: look for this thread's entry in the newest table.
    fn get_fast(&self, id: usize) -> Option<&T> {
        // Acquire pairs with the Release store in `insert` when the table is
        // replaced, making the new table's contents visible to this load.
        let table = unsafe { &*self.table.load(Ordering::Acquire) };
        match Self::lookup(id, table) {
            Some(x) => unsafe { Some((*x.get()).as_ref().unchecked_unwrap()) },
            None => self.get_slow(id, table),
        }
    }

    // Slow path: search the older tables for this thread's entry and, if it
    // is found, migrate it into the newest table.
    #[cold]
    fn get_slow(&self, id: usize, table_top: &Table<T>) -> Option<&T> {
        let mut current = &table_top.prev;
        while let Some(ref table) = *current {
            if let Some(x) = Self::lookup(id, table) {
                let data = unsafe { (*x.get()).take().unchecked_unwrap() };
                return Some(self.insert(id, data, false));
            }
            current = &table.prev;
        }
        None
    }

    #[cold]
    fn insert(&self, id: usize, data: Box<T>, new: bool) -> &T {
        // Insertions are serialized by the mutex, whose guarded counter
        // tracks the total number of entries in use.
        let mut count = self.lock.lock().unwrap();
        if new {
            *count += 1;
        }
        let table_raw = self.table.load(Ordering::Relaxed);
        let table = unsafe { &*table_raw };

        // If the current table is more than 75% full, replace it with one
        // twice as large, chaining the old table behind the new one.
        let table = if *count > table.entries.len() * 3 / 4 {
            let entry = TableEntry {
                owner: AtomicUsize::new(0),
                data: UnsafeCell::new(None),
            };
            let new_table = Box::into_raw(Box::new(Table {
                entries: vec![entry; table.entries.len() * 2].into_boxed_slice(),
                hash_bits: table.hash_bits + 1,
                prev: unsafe { Some(Box::from_raw(table_raw)) },
            }));
            self.table.store(new_table, Ordering::Release);
            unsafe { &*new_table }
        } else {
            table
        };

        // Linear-probe for a free slot starting at this id's hash position.
        for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) {
            let owner = entry.owner.load(Ordering::Relaxed);
            if owner == 0 {
                unsafe {
                    entry.owner.store(id, Ordering::Relaxed);
                    *entry.data.get() = Some(data);
                    return (*entry.data.get()).as_ref().unchecked_unwrap();
                }
            }
            if owner == id {
                // The `create` closure itself inserted a value for this
                // thread between `get_fast` and `insert`; return the value
                // that is already present and drop the new one.
                unsafe {
                    return (*entry.data.get()).as_ref().unchecked_unwrap();
                }
            }
        }
        unreachable!();
    }

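    /// Returns a mutable iterator over the local values of all threads.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, no other thread can
    /// be concurrently accessing it, so this is safe.
    ///
    /// A minimal usage sketch (assuming the crate name `thread_local`):
    ///
    /// ```
    /// use thread_local::ThreadLocal;
    ///
    /// let mut tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| Box::new(1));
    /// for value in tls.iter_mut() {
    ///     **value += 1;
    /// }
    /// assert_eq!(tls.get(), Some(&2));
    /// ```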
    pub fn iter_mut(&mut self) -> IterMut<T> {
        let raw = RawIter {
            remaining: *self.lock.lock().unwrap(),
            index: 0,
            table: self.table.load(Ordering::Relaxed),
        };
        IterMut {
            raw: raw,
            marker: PhantomData,
        }
    }

    /// Removes all thread-specific values from the `ThreadLocal`, effectively
    /// resetting it to its original state.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, no other thread can
    /// be concurrently accessing it, so this is safe.
    pub fn clear(&mut self) {
        *self = ThreadLocal::new();
    }
}

impl<T: ?Sized + Send> IntoIterator for ThreadLocal<T> {
    type Item = Box<T>;
    type IntoIter = IntoIter<T>;

    fn into_iter(self) -> IntoIter<T> {
        let raw = RawIter {
            remaining: *self.lock.lock().unwrap(),
            index: 0,
            table: self.table.load(Ordering::Relaxed),
        };
        IntoIter {
            raw: raw,
            _thread_local: self,
        }
    }
}

impl<'a, T: ?Sized + Send + 'a> IntoIterator for &'a mut ThreadLocal<T> {
    type Item = &'a mut Box<T>;
    type IntoIter = IterMut<'a, T>;

    fn into_iter(self) -> IterMut<'a, T> {
        self.iter_mut()
    }
}

impl<T: Send + Default> ThreadLocal<T> {
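    /// Returns the element for the current thread, creating it with
    /// `T::default()` if it doesn't yet exist.
    ///
    /// A minimal usage sketch (assuming the crate name `thread_local`):
    ///
    /// ```
    /// use thread_local::ThreadLocal;
    ///
    /// let tls: ThreadLocal<Vec<i32>> = ThreadLocal::new();
    /// assert!(tls.get_default().is_empty());
    /// ```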
    pub fn get_default(&self) -> &T {
        self.get_or(|| Box::new(T::default()))
    }
}

impl<T: ?Sized + Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
    }
}

impl<T: ?Sized + Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}

// Raw iterator over every occupied entry, walking from the newest table back
// through the chain of previous tables. `remaining` is the entry count taken
// from the mutex, so iteration stops exactly when all values have been seen.
struct RawIter<T: ?Sized + Send> {
    remaining: usize,
    index: usize,
    table: *const Table<T>,
}

impl<T: ?Sized + Send> RawIter<T> {
    fn next(&mut self) -> Option<*mut Option<Box<T>>> {
        if self.remaining == 0 {
            return None;
        }

        loop {
            let entries = unsafe { &(*self.table).entries[..] };
            while self.index < entries.len() {
                let val = entries[self.index].data.get();
                self.index += 1;
                if unsafe { (*val).is_some() } {
                    self.remaining -= 1;
                    return Some(val);
                }
            }
            self.index = 0;
            self.table = unsafe { &**(*self.table).prev.as_ref().unchecked_unwrap() };
        }
    }
}

/// Mutable iterator over the contents of a `ThreadLocal`.
pub struct IterMut<'a, T: ?Sized + Send + 'a> {
    raw: RawIter<T>,
    marker: PhantomData<&'a mut ThreadLocal<T>>,
}

impl<'a, T: ?Sized + Send + 'a> Iterator for IterMut<'a, T> {
    type Item = &'a mut Box<T>;

    fn next(&mut self) -> Option<&'a mut Box<T>> {
        self.raw.next().map(|x| unsafe { (*x).as_mut().unchecked_unwrap() })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.raw.remaining, Some(self.raw.remaining))
    }
}

impl<'a, T: ?Sized + Send + 'a> ExactSizeIterator for IterMut<'a, T> {}

/// An iterator that moves out of a `ThreadLocal`.
pub struct IntoIter<T: ?Sized + Send> {
    raw: RawIter<T>,
    _thread_local: ThreadLocal<T>,
}

impl<T: ?Sized + Send> Iterator for IntoIter<T> {
    type Item = Box<T>;

    fn next(&mut self) -> Option<Box<T>> {
        self.raw.next().map(|x| unsafe { (*x).take().unchecked_unwrap() })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.raw.remaining, Some(self.raw.remaining))
    }
}

impl<T: ?Sized + Send> ExactSizeIterator for IntoIter<T> {}

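/// Wrapper around `ThreadLocal` which adds a fast path for a single thread.
///
/// The value and id of the first thread to create an entry are stored outside
/// the hash table, which makes repeated access from that thread faster at the
/// cost of a slightly slower path for all other threads.
///
/// A minimal usage sketch (assuming the crate name `thread_local`):
///
/// ```
/// use thread_local::CachedThreadLocal;
///
/// let tls: CachedThreadLocal<u32> = CachedThreadLocal::new();
/// assert_eq!(*tls.get_or(|| Box::new(7)), 7);
/// assert_eq!(tls.get(), Some(&7));
/// ```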
pub struct CachedThreadLocal<T: ?Sized + Send> {
    // Id of the thread that owns the cached `local` slot (0 if unclaimed).
    owner: AtomicUsize,
    // Fast-path storage for the owning thread's value.
    local: UnsafeCell<Option<Box<T>>>,
    // Fallback table used by every other thread.
    global: ThreadLocal<T>,
}

// As with `ThreadLocal`, each thread only ever accesses its own data, so
// `CachedThreadLocal` is always `Sync`.
unsafe impl<T: ?Sized + Send> Sync for CachedThreadLocal<T> {}

impl<T: ?Sized + Send> Default for CachedThreadLocal<T> {
    fn default() -> CachedThreadLocal<T> {
        CachedThreadLocal::new()
    }
}

impl<T: ?Sized + Send> CachedThreadLocal<T> {
    /// Creates a new empty `CachedThreadLocal`.
    pub fn new() -> CachedThreadLocal<T> {
        CachedThreadLocal {
            owner: AtomicUsize::new(0),
            local: UnsafeCell::new(None),
            global: ThreadLocal::new(),
        }
    }

    /// Returns the element for the current thread, if it exists.
    pub fn get(&self) -> Option<&T> {
        let id = thread_id::get();
        let owner = self.owner.load(Ordering::Relaxed);
        if owner == id {
            return unsafe { Some((*self.local.get()).as_ref().unchecked_unwrap()) };
        }
        if owner == 0 {
            return None;
        }
        self.global.get_fast(id)
    }

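    /// Returns the element for the current thread, creating it with
    /// `create` if it doesn't yet exist.
    ///
    /// A minimal usage sketch (assuming the crate name `thread_local`):
    ///
    /// ```
    /// use thread_local::CachedThreadLocal;
    ///
    /// let tls: CachedThreadLocal<String> = CachedThreadLocal::new();
    /// let s = tls.get_or(|| Box::new(String::from("hello")));
    /// assert_eq!(s, "hello");
    /// ```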
    #[inline(always)]
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> Box<T>,
    {
        unsafe {
            self.get_or_try(|| Ok::<Box<T>, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

    /// Returns the element for the current thread, or creates it with
    /// `create` if it doesn't yet exist. Errors from `create` are passed
    /// through.
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<Box<T>, E>,
    {
        let id = thread_id::get();
        let owner = self.owner.load(Ordering::Relaxed);
        if owner == id {
            return Ok(unsafe { (*self.local.get()).as_ref().unchecked_unwrap() });
        }
        self.get_or_try_slow(id, owner, create)
    }

    #[cold]
    #[inline(never)]
    fn get_or_try_slow<F, E>(&self, id: usize, owner: usize, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<Box<T>, E>,
    {
        if owner == 0 && self.owner.compare_and_swap(0, id, Ordering::Relaxed) == 0 {
            // We won the race to claim the cached slot. If `create` fails,
            // release the slot again: otherwise a later `get` from this
            // thread would see `owner == id` with no value stored.
            let data = match create() {
                Ok(data) => data,
                Err(e) => {
                    self.owner.store(0, Ordering::Relaxed);
                    return Err(e);
                }
            };
            unsafe {
                (*self.local.get()) = Some(data);
                return Ok((*self.local.get()).as_ref().unchecked_unwrap());
            }
        }
        match self.global.get_fast(id) {
            Some(x) => Ok(x),
            None => Ok(self.global.insert(id, create()?, true)),
        }
    }

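    /// Returns a mutable iterator over the local values of all threads.
    ///
    /// Since this call borrows the `CachedThreadLocal` mutably, no other
    /// thread can be concurrently accessing it, so this is safe.
    ///
    /// A minimal usage sketch (assuming the crate name `thread_local`):
    ///
    /// ```
    /// use thread_local::CachedThreadLocal;
    ///
    /// let mut tls: CachedThreadLocal<u32> = CachedThreadLocal::new();
    /// tls.get_or(|| Box::new(1));
    /// let total: u32 = tls.iter_mut().map(|x| **x).sum();
    /// assert_eq!(total, 1);
    /// ```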
    pub fn iter_mut(&mut self) -> CachedIterMut<T> {
        unsafe {
            (*self.local.get())
                .as_mut()
                .into_iter()
                .chain(self.global.iter_mut())
        }
    }

    /// Removes all thread-specific values from the `CachedThreadLocal`,
    /// effectively resetting it to its original state.
    ///
    /// Since this call borrows the `CachedThreadLocal` mutably, no other
    /// thread can be concurrently accessing it, so this is safe.
    pub fn clear(&mut self) {
        *self = CachedThreadLocal::new();
    }
}

impl<T: ?Sized + Send> IntoIterator for CachedThreadLocal<T> {
    type Item = Box<T>;
    type IntoIter = CachedIntoIter<T>;

    fn into_iter(self) -> CachedIntoIter<T> {
        unsafe {
            (*self.local.get())
                .take()
                .into_iter()
                .chain(self.global.into_iter())
        }
    }
}

impl<'a, T: ?Sized + Send + 'a> IntoIterator for &'a mut CachedThreadLocal<T> {
    type Item = &'a mut Box<T>;
    type IntoIter = CachedIterMut<'a, T>;

    fn into_iter(self) -> CachedIterMut<'a, T> {
        self.iter_mut()
    }
}

impl<T: Send + Default> CachedThreadLocal<T> {
    /// Returns the element for the current thread, creating it with
    /// `T::default()` if it doesn't yet exist.
    pub fn get_default(&self) -> &T {
        self.get_or(|| Box::new(T::default()))
    }
}

impl<T: ?Sized + Send + fmt::Debug> fmt::Debug for CachedThreadLocal<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Note: deliberately formats as "ThreadLocal" so the Debug output
        // matches the uncached version (the tests below rely on this).
        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
    }
}


/// Mutable iterator over the contents of a `CachedThreadLocal`.
pub type CachedIterMut<'a, T> = Chain<OptionIter<&'a mut Box<T>>, IterMut<'a, T>>;

/// An iterator that moves out of a `CachedThreadLocal`.
pub type CachedIntoIter<T> = Chain<OptionIter<Box<T>>, IntoIter<T>>;

impl<T: ?Sized + Send + UnwindSafe> UnwindSafe for CachedThreadLocal<T> {}

#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::thread;
    use super::{ThreadLocal, CachedThreadLocal};

    fn make_create() -> Arc<Fn() -> Box<usize> + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || Box::new(count.fetch_add(1, Relaxed)))
    }

    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn same_thread_cached() {
        let create = make_create();
        let mut tls = CachedThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        }).join()
            .unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn different_thread_cached() {
        let create = make_create();
        let tls = Arc::new(CachedThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        }).join()
            .unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn iter() {
        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || { tls3.get_or(|| Box::new(3)); })
                .join()
                .unwrap();
        }).join()
            .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();
        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn iter_cached() {
        let tls = Arc::new(CachedThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || { tls3.get_or(|| Box::new(3)); })
                .join()
                .unwrap();
        }).join()
            .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();
        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn is_sync() {
        fn foo<T: Sync>() {}
        foo::<ThreadLocal<String>>();
        foo::<ThreadLocal<RefCell<String>>>();
        foo::<CachedThreadLocal<String>>();
        foo::<CachedThreadLocal<RefCell<String>>>();
    }
}