1use std::cell::UnsafeCell;
2use std::cmp;
3use std::fmt;
4use std::mem;
5use std::mem::ManuallyDrop;
6use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT};
7
8use errors::InvalidThreadAccess;
9
/// Hands out a fresh, process-wide thread identifier.
///
/// Each call atomically increments a global counter and returns its previous
/// value, so IDs are unique (barring wrap-around after `usize::MAX` calls).
fn next_thread_id() -> usize {
    // A plain `static` suffices: atomics provide interior mutability, so no
    // `static mut` (and no `unsafe` block to touch it) is needed.
    static COUNTER: AtomicUsize = ATOMIC_USIZE_INIT;
    COUNTER.fetch_add(1, Ordering::SeqCst)
}
14
15pub(crate) fn get_thread_id() -> usize {
16 thread_local!(static THREAD_ID: usize = next_thread_id());
17 THREAD_ID.with(|&x| x)
18}
19
/// A container whose wrapped value may only be accessed from the thread that
/// created it.
pub struct Fragile<T> {
    // ID (as returned by `get_thread_id`) of the thread that created this
    // container; every access compares the caller's ID against it.
    thread_id: usize,
    // The wrapped value, boxed inside an `UnsafeCell`. It is held in
    // `ManuallyDrop` so the container decides explicitly when (and whether)
    // the value gets dropped.
    value: ManuallyDrop<UnsafeCell<Box<T>>>,
}
32
33impl<T> Fragile<T> {
34 pub fn new(value: T) -> Self {
41 Fragile {
42 value: ManuallyDrop::new(UnsafeCell::new(Box::new(value))),
43 thread_id: get_thread_id(),
44 }
45 }
46
47 pub fn is_valid(&self) -> bool {
51 get_thread_id() == self.thread_id
52 }
53
54 #[inline(always)]
55 fn assert_thread(&self) {
56 if !self.is_valid() {
57 panic!("trying to access wrapped value in fragile container from incorrect thread.");
58 }
59 }
60
61 pub fn into_inner(mut self) -> T {
68 self.assert_thread();
69 unsafe {
70 let value = mem::replace(&mut self.value, mem::uninitialized());
71 mem::forget(self);
72 *ManuallyDrop::into_inner(value).into_inner()
73 }
74 }
75
76 pub fn try_into_inner(self) -> Result<T, Self> {
82 if get_thread_id() == self.thread_id {
83 Ok(self.into_inner())
84 } else {
85 Err(self)
86 }
87 }
88
89 pub fn get(&self) -> &T {
96 self.assert_thread();
97 unsafe { &*self.value.get() }
98 }
99
100 pub fn get_mut(&mut self) -> &mut T {
107 self.assert_thread();
108 unsafe { &mut *self.value.get() }
109 }
110
111 pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
115 if get_thread_id() == self.thread_id {
116 unsafe { Ok(&*self.value.get()) }
117 } else {
118 Err(InvalidThreadAccess)
119 }
120 }
121
122 pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
126 if get_thread_id() == self.thread_id {
127 unsafe { Ok(&mut *self.value.get()) }
128 } else {
129 Err(InvalidThreadAccess)
130 }
131 }
132}
133
134impl<T> Drop for Fragile<T> {
135 fn drop(&mut self) {
136 if mem::needs_drop::<T>() {
137 if get_thread_id() == self.thread_id {
138 unsafe { ManuallyDrop::drop(&mut self.value) }
139 } else {
140 panic!("destructor of fragile object ran on wrong thread");
141 }
142 }
143 }
144}
145
146impl<T> From<T> for Fragile<T> {
147 #[inline]
148 fn from(t: T) -> Fragile<T> {
149 Fragile::new(t)
150 }
151}
152
153impl<T: Clone> Clone for Fragile<T> {
154 #[inline]
155 fn clone(&self) -> Fragile<T> {
156 Fragile::new(self.get().clone())
157 }
158}
159
160impl<T: Default> Default for Fragile<T> {
161 #[inline]
162 fn default() -> Fragile<T> {
163 Fragile::new(T::default())
164 }
165}
166
167impl<T: PartialEq> PartialEq for Fragile<T> {
168 #[inline]
169 fn eq(&self, other: &Fragile<T>) -> bool {
170 *self.get() == *other.get()
171 }
172}
173
174impl<T: Eq> Eq for Fragile<T> {}
175
176impl<T: PartialOrd> PartialOrd for Fragile<T> {
177 #[inline]
178 fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
179 self.get().partial_cmp(&*other.get())
180 }
181
182 #[inline]
183 fn lt(&self, other: &Fragile<T>) -> bool {
184 *self.get() < *other.get()
185 }
186
187 #[inline]
188 fn le(&self, other: &Fragile<T>) -> bool {
189 *self.get() <= *other.get()
190 }
191
192 #[inline]
193 fn gt(&self, other: &Fragile<T>) -> bool {
194 *self.get() > *other.get()
195 }
196
197 #[inline]
198 fn ge(&self, other: &Fragile<T>) -> bool {
199 *self.get() >= *other.get()
200 }
201}
202
203impl<T: Ord> Ord for Fragile<T> {
204 #[inline]
205 fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
206 self.get().cmp(&*other.get())
207 }
208}
209
210impl<T: fmt::Display> fmt::Display for Fragile<T> {
211 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
212 fmt::Display::fmt(self.get(), f)
213 }
214}
215
216impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
217 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
218 match self.try_get() {
219 Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
220 Err(..) => {
221 struct InvalidPlaceholder;
222 impl fmt::Debug for InvalidPlaceholder {
223 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
224 f.write_str("<invalid thread>")
225 }
226 }
227
228 f.debug_struct("Fragile")
229 .field("value", &InvalidPlaceholder)
230 .finish()
231 }
232 }
233 }
234}
235
236unsafe impl<T> Sync for Fragile<T> {}
240
241unsafe impl<T> Send for Fragile<T> {}
243
#[test]
fn test_basic() {
    use std::thread;

    let fragile = Fragile::new(true);
    assert_eq!(fragile.to_string(), "true");
    assert_eq!(fragile.get(), &true);
    assert!(fragile.try_get().is_ok());

    let worker = thread::spawn(move || {
        // From a foreign thread the checked accessor must refuse access.
        assert!(fragile.try_get().is_err());
    });
    worker.join().unwrap();
}
256
#[test]
fn test_mut() {
    let mut flag = Fragile::new(true);
    {
        let slot = flag.get_mut();
        *slot = false;
    }
    assert_eq!(flag.to_string(), "false");
    assert_eq!(flag.get(), &false);
}
264
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let fragile = Fragile::new(true);
    let worker = thread::spawn(move || {
        // `get` asserts the owning thread, so this panics in the worker;
        // `unwrap` on the join result re-raises it in the test thread.
        fragile.get();
    });
    worker.join().unwrap();
}
275
#[test]
fn test_noop_drop_elsewhere() {
    use std::thread;

    let fragile = Fragile::new(true);
    let worker = thread::spawn(move || {
        // `bool` has no destructor, so dropping the container here must not
        // panic even though this is not the owning thread.
        fragile.try_get().ok();
    });
    worker.join().unwrap();
}
286
#[test]
fn test_panic_on_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Records whether its destructor ever ran.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let val = Fragile::new(X(Arc::clone(&was_called)));

    let outcome = thread::spawn(move || {
        val.try_get().ok();
    }).join();

    // Dropping on the wrong thread panics, and `X::drop` must not have run.
    assert!(outcome.is_err());
    assert!(!was_called.load(Ordering::SeqCst));
}
308
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::thread;

    let (sender, receiver) = channel();
    let val = Fragile::new(Rc::new(true));

    // The worker may hold and forward the `Rc` but never touch it.
    let worker = thread::spawn(move || {
        assert!(val.try_get().is_err());
        sender.send(val).unwrap();
    });

    // Back on the owning thread, access works again.
    let returned = receiver.recv().unwrap();
    assert!(**returned.get());

    worker.join().unwrap();
}