use std::cell::RefCell;
use once_cell::sync::OnceCell;
use parking_lot::RwLock;
use crate::{RecordingStream, StoreKind};
/// Which registry slot a stream lives in: the process-wide one or the calling thread's own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RecordingScope {
    /// The process-wide slot, visible from every thread.
    Global,

    /// The slot belonging to the calling thread only.
    ThreadLocal,
}

impl std::fmt::Display for RecordingScope {
    /// Renders the scope as a short lowercase label (`global` / `thread-local`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Global => "global",
            Self::ThreadLocal => "thread-local",
        };
        f.write_str(label)
    }
}
/// Storage cell for one thread-local `RecordingStream` (empty by default).
#[derive(Default)]
struct ThreadLocalRecording {
    stream: Option<RecordingStream>,
}

impl ThreadLocalRecording {
    /// Puts `stream` into the slot and returns whatever the slot held before.
    fn replace(&mut self, stream: Option<RecordingStream>) -> Option<RecordingStream> {
        let previous = self.stream.take();
        self.stream = stream;
        previous
    }

    /// Returns a clone of the stored stream, or `None` if the slot is empty.
    fn get(&self) -> Option<RecordingStream> {
        self.stream.as_ref().cloned()
    }
}
#[cfg(any(target_os = "macos", target_os = "windows"))]
impl Drop for ThreadLocalRecording {
    /// On macOS and Windows, a stream still held by this slot when it is dropped is leaked
    /// instead of destroyed.
    ///
    /// Before leaking it, this logs a warning and sleeps for 500ms. See
    /// https://github.com/rerun-io/rerun/issues/3937 for the data-loss problem this works around.
    fn drop(&mut self) {
        if let Some(leaked) = self.stream.take() {
            re_log::warn!("Using thread-local RecordingStream on macOS & Windows can result in data loss because of https://github.com/rerun-io/rerun/issues/3937");

            // NOTE(review): presumably gives the stream's background work time to finish
            // before it is abandoned — confirm against the linked issue.
            let grace = std::time::Duration::from_millis(500);
            std::thread::sleep(grace);

            // Deliberately skip the stream's destructor (see issue linked above).
            #[allow(clippy::mem_forget)]
            std::mem::forget(leaked);
        }
    }
}
static GLOBAL_DATA_RECORDING: OnceCell<RwLock<Option<RecordingStream>>> = OnceCell::new();
thread_local! {
static LOCAL_DATA_RECORDING: RefCell<ThreadLocalRecording> = Default::default();
}
static GLOBAL_BLUEPRINT_RECORDING: OnceCell<RwLock<Option<RecordingStream>>> = OnceCell::new();
thread_local! {
static LOCAL_BLUEPRINT_RECORDING: RefCell<ThreadLocalRecording> = Default::default();
}
/// Forgets every registered stream that reports itself as a forked child
/// (`RecordingStream::is_forked_child`).
///
/// Checks are made in this order:
/// 1. global recording;
/// 2. global blueprint;
/// 3. thread-local recording;
/// 4. thread-local blueprint.
///
/// The thread-local ones are for the calling thread only. A forked stream is removed with
/// `RecordingStream::forget_global` / `RecordingStream::forget_thread_local`, which leak it
/// rather than running its destructor.
pub fn cleanup_if_forked_child() {
    // Each kind is paired with its own label. This keeps the check, the log line and the
    // `forget_*` call pointing at the same store kind. A copy-pasted branch previously forgot
    // the global *recording* when the global *blueprint* was the forked one.
    let kinds = [
        (StoreKind::Recording, "recording"),
        (StoreKind::Blueprint, "blueprint"),
    ];

    for (kind, label) in kinds {
        if let Some(global) = RecordingStream::global(kind) {
            if global.is_forked_child() {
                re_log::debug!("Fork detected. Forgetting global {label}");
                RecordingStream::forget_global(kind);
            }
        }
    }

    for (kind, label) in kinds {
        if let Some(local) = RecordingStream::thread_local(kind) {
            if local.is_forked_child() {
                re_log::debug!("Fork detected. Forgetting thread-local {label}");
                RecordingStream::forget_thread_local(kind);
            }
        }
    }
}
impl RecordingStream {
    /// Returns the stream to use for logging data of the given `kind`.
    ///
    /// Resolution order:
    /// 1. `overrides`, if it is `Some`;
    /// 2. the calling thread's thread-local stream of that `kind`;
    /// 3. the global stream of that `kind`.
    ///
    /// When none of these is set, returns `None` and logs through `re_log::warn_once!`.
    #[inline]
    pub fn get(kind: StoreKind, overrides: Option<RecordingStream>) -> Option<RecordingStream> {
        let rec = Self::get_with_fallbacks(kind, overrides);

        if rec.is_none() {
            re_log::warn_once!(
                "There is no currently active {kind} stream available \
                for the current thread ({:?}): have you called `set_global()` and/or \
                `set_thread_local()` first?",
                std::thread::current().id(),
            );
        }

        rec
    }

    /// Like [`Self::get`], but reports a missing stream through `re_log::debug_once!` instead of
    /// a warning.
    #[inline]
    #[doc(hidden)]
    pub fn get_quiet(
        kind: StoreKind,
        overrides: Option<RecordingStream>,
    ) -> Option<RecordingStream> {
        let rec = Self::get_with_fallbacks(kind, overrides);

        if rec.is_none() {
            re_log::debug_once!(
                "There is no currently active {kind} stream available \
                for the current thread ({:?}): have you called `set_global()` and/or \
                `set_thread_local()` first?",
                std::thread::current().id(),
            );
        }

        rec
    }

    /// Resolution shared by [`Self::get`] and [`Self::get_quiet`]. It tries `overrides`, then the
    /// thread-local stream, then the global stream, and does not log.
    ///
    /// Later sources are only queried if the earlier ones are `None`.
    fn get_with_fallbacks(
        kind: StoreKind,
        overrides: Option<RecordingStream>,
    ) -> Option<RecordingStream> {
        overrides.or_else(|| {
            Self::get_any(RecordingScope::ThreadLocal, kind)
                .or_else(|| Self::get_any(RecordingScope::Global, kind))
        })
    }

    /// Returns a clone of the currently registered global stream of the given `kind`, if any.
    #[inline]
    pub fn global(kind: StoreKind) -> Option<RecordingStream> {
        Self::get_any(RecordingScope::Global, kind)
    }

    /// Replaces the global stream of the given `kind` with `rec` and returns the previous one.
    ///
    /// Passing `None` clears the global slot.
    #[inline]
    pub fn set_global(kind: StoreKind, rec: Option<RecordingStream>) -> Option<RecordingStream> {
        Self::set_any(RecordingScope::Global, kind, rec)
    }

    /// Clears the global stream of the given `kind` *without* running its destructor.
    ///
    /// The stream is leaked via `std::mem::forget`.
    #[inline]
    pub fn forget_global(kind: StoreKind) {
        Self::forget_any(RecordingScope::Global, kind);
    }

    /// Returns a clone of the calling thread's thread-local stream of the given `kind`, if any.
    #[inline]
    pub fn thread_local(kind: StoreKind) -> Option<RecordingStream> {
        Self::get_any(RecordingScope::ThreadLocal, kind)
    }

    /// Replaces the calling thread's thread-local stream of the given `kind` with `rec` and
    /// returns the previous one.
    ///
    /// Passing `None` clears the slot for this thread.
    #[inline]
    pub fn set_thread_local(
        kind: StoreKind,
        rec: Option<RecordingStream>,
    ) -> Option<RecordingStream> {
        Self::set_any(RecordingScope::ThreadLocal, kind, rec)
    }

    /// Clears the calling thread's thread-local stream of the given `kind` *without* running
    /// its destructor.
    ///
    /// The stream is leaked via `std::mem::forget`.
    #[inline]
    pub fn forget_thread_local(kind: StoreKind) {
        Self::forget_any(RecordingScope::ThreadLocal, kind);
    }

    /// Returns a clone of the stream stored in the (`scope`, `kind`) slot.
    ///
    /// Global slots are lazily created (holding `None`) on first access.
    fn get_any(scope: RecordingScope, kind: StoreKind) -> Option<RecordingStream> {
        match kind {
            StoreKind::Recording => match scope {
                RecordingScope::Global => GLOBAL_DATA_RECORDING
                    .get_or_init(Default::default)
                    .read()
                    .clone(),
                RecordingScope::ThreadLocal => {
                    LOCAL_DATA_RECORDING.with(|rec| rec.borrow().get())
                }
            },
            StoreKind::Blueprint => match scope {
                RecordingScope::Global => GLOBAL_BLUEPRINT_RECORDING
                    .get_or_init(Default::default)
                    .read()
                    .clone(),
                RecordingScope::ThreadLocal => {
                    LOCAL_BLUEPRINT_RECORDING.with(|rec| rec.borrow().get())
                }
            },
        }
    }

    /// Stores `rec` in the (`scope`, `kind`) slot and returns the value it replaced.
    ///
    /// The previous stream is handed back to the caller rather than dropped here, so it is
    /// dropped after the lock/borrow has been released.
    fn set_any(
        scope: RecordingScope,
        kind: StoreKind,
        rec: Option<RecordingStream>,
    ) -> Option<RecordingStream> {
        match kind {
            StoreKind::Recording => match scope {
                RecordingScope::Global => std::mem::replace(
                    &mut *GLOBAL_DATA_RECORDING.get_or_init(Default::default).write(),
                    rec,
                ),
                RecordingScope::ThreadLocal => {
                    LOCAL_DATA_RECORDING.with(|cell| cell.borrow_mut().replace(rec))
                }
            },
            StoreKind::Blueprint => match scope {
                RecordingScope::Global => std::mem::replace(
                    &mut *GLOBAL_BLUEPRINT_RECORDING
                        .get_or_init(Default::default)
                        .write(),
                    rec,
                ),
                RecordingScope::ThreadLocal => {
                    LOCAL_BLUEPRINT_RECORDING.with(|cell| cell.borrow_mut().replace(rec))
                }
            },
        }
    }

    /// Empties the (`scope`, `kind`) slot and leaks its content with `std::mem::forget`, so no
    /// destructor runs.
    ///
    /// For the thread-local case, the whole `ThreadLocalRecording` is leaked, so its
    /// platform-specific `Drop` does not run either. A global slot that was never accessed is
    /// left uninitialized.
    // Leaking is the whole point of this function: `cleanup_if_forked_child` uses it for
    // streams inherited across `fork()`.
    #[allow(clippy::mem_forget)]
    fn forget_any(scope: RecordingScope, kind: StoreKind) {
        match kind {
            StoreKind::Recording => match scope {
                RecordingScope::Global => {
                    if let Some(global) = GLOBAL_DATA_RECORDING.get() {
                        std::mem::forget(global.write().take());
                    }
                }
                RecordingScope::ThreadLocal => LOCAL_DATA_RECORDING.with(|cell| {
                    std::mem::forget(cell.take());
                }),
            },
            StoreKind::Blueprint => match scope {
                RecordingScope::Global => {
                    if let Some(global) = GLOBAL_BLUEPRINT_RECORDING.get() {
                        std::mem::forget(global.write().take());
                    }
                }
                RecordingScope::ThreadLocal => LOCAL_BLUEPRINT_RECORDING.with(|cell| {
                    std::mem::forget(cell.take());
                }),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::RecordingStreamBuilder;
    use super::*;

    /// Exercises the resolution order of `RecordingStream::get`: an explicit override wins, then
    /// the thread-local stream, then the global one. This is checked on the main thread and on a
    /// spawned thread.
    ///
    /// NOTE(review): this mutates process-wide global slots, so every assertion depends on the
    /// exact statement order. It also presumably assumes no other test in this binary touches
    /// the globals concurrently — confirm.
    #[test]
    fn fallbacks() {
        // Compares streams by their store id; panics if `got` is `None`.
        fn check_store_id(expected: &RecordingStream, got: Option<RecordingStream>) {
            assert_eq!(
                expected.store_info().unwrap().store_id,
                got.unwrap().store_info().unwrap().store_id
            );
        }

        // Nothing registered yet: neither kind resolves.
        assert!(RecordingStream::get(StoreKind::Recording, None).is_none());
        assert!(RecordingStream::get(StoreKind::Blueprint, None).is_none());

        // An explicit override is returned as-is, for either kind.
        let explicit = RecordingStreamBuilder::new("rerun_example_explicit")
            .buffered()
            .unwrap();
        check_store_id(
            &explicit,
            RecordingStream::get(StoreKind::Recording, explicit.clone().into()),
        );
        check_store_id(
            &explicit,
            RecordingStream::get(StoreKind::Blueprint, explicit.clone().into()),
        );

        // Register globals; the first `set_global` for each kind returns no previous stream.
        let global_data = RecordingStreamBuilder::new("rerun_example_global_data")
            .buffered()
            .unwrap();
        assert!(
            RecordingStream::set_global(StoreKind::Recording, Some(global_data.clone())).is_none()
        );
        let global_blueprint = RecordingStreamBuilder::new("rerun_example_global_blueprint")
            .buffered()
            .unwrap();
        assert!(
            RecordingStream::set_global(StoreKind::Blueprint, Some(global_blueprint.clone()))
                .is_none()
        );

        // With no thread-local set, `get` falls back to the globals.
        check_store_id(
            &global_data,
            RecordingStream::get(StoreKind::Recording, None),
        );
        check_store_id(
            &global_blueprint,
            RecordingStream::get(StoreKind::Blueprint, None),
        );

        // Re-setting a global returns the previously registered stream.
        check_store_id(
            &global_data,
            RecordingStream::set_global(StoreKind::Recording, Some(global_data.clone())),
        );
        check_store_id(
            &global_blueprint,
            RecordingStream::set_global(StoreKind::Blueprint, Some(global_blueprint.clone())),
        );

        // A fresh thread sees the globals, and its own thread-locals take precedence over them.
        std::thread::Builder::new()
            .spawn({
                let global_data = global_data.clone();
                let global_blueprint = global_blueprint.clone();
                move || {
                    check_store_id(
                        &global_data,
                        RecordingStream::get(StoreKind::Recording, None),
                    );
                    check_store_id(
                        &global_blueprint,
                        RecordingStream::get(StoreKind::Blueprint, None),
                    );

                    let local_data = RecordingStreamBuilder::new("rerun_example_local_data")
                        .buffered()
                        .unwrap();
                    assert!(RecordingStream::set_thread_local(
                        StoreKind::Recording,
                        Some(local_data.clone())
                    )
                    .is_none());

                    let local_blueprint =
                        RecordingStreamBuilder::new("rerun_example_local_blueprint")
                            .buffered()
                            .unwrap();
                    assert!(RecordingStream::set_thread_local(
                        StoreKind::Blueprint,
                        Some(local_blueprint.clone())
                    )
                    .is_none());

                    // Thread-local now shadows the global.
                    check_store_id(
                        &local_data,
                        RecordingStream::get(StoreKind::Recording, None),
                    );
                    check_store_id(
                        &local_blueprint,
                        RecordingStream::get(StoreKind::Blueprint, None),
                    );

                    // ...but an explicit override still wins over both.
                    check_store_id(
                        &explicit,
                        RecordingStream::get(StoreKind::Recording, explicit.clone().into()),
                    );
                    check_store_id(
                        &explicit,
                        RecordingStream::get(StoreKind::Blueprint, explicit.clone().into()),
                    );
                }
            })
            .unwrap()
            .join()
            .unwrap();

        // The spawned thread's thread-locals did not leak into the main thread.
        check_store_id(
            &global_data,
            RecordingStream::get(StoreKind::Recording, None),
        );
        check_store_id(
            &global_blueprint,
            RecordingStream::get(StoreKind::Blueprint, None),
        );

        // Register thread-locals on the main thread too.
        let local_data = RecordingStreamBuilder::new("rerun_example_local_data")
            .buffered()
            .unwrap();
        assert!(
            RecordingStream::set_thread_local(StoreKind::Recording, Some(local_data.clone()))
                .is_none()
        );
        let local_blueprint = RecordingStreamBuilder::new("rerun_example_local_blueprint")
            .buffered()
            .unwrap();
        assert!(RecordingStream::set_thread_local(
            StoreKind::Blueprint,
            Some(local_blueprint.clone())
        )
        .is_none());

        // Clearing the globals returns them...
        check_store_id(
            &global_data,
            RecordingStream::set_global(StoreKind::Recording, None),
        );
        check_store_id(
            &global_blueprint,
            RecordingStream::set_global(StoreKind::Blueprint, None),
        );

        // ...and `get` still resolves to the main thread's thread-locals.
        check_store_id(
            &local_data,
            RecordingStream::get(StoreKind::Recording, None),
        );
        check_store_id(
            &local_blueprint,
            RecordingStream::get(StoreKind::Blueprint, None),
        );
    }
}