1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
// Copyright 2022-2024 Andrew D. Straw.

use std::path::Path;

use color_eyre::{
    eyre::{self as anyhow, WrapErr},
    Result,
};

use crate::h264_source::{H264Source, SeekRead, SeekableH264Source};
use mp4::MediaType;

/// Location of a single NAL unit within an MP4 file.
///
/// Identifies the MP4 sample (by track and 1-based sample id) that contains
/// the NAL unit, plus the index of the NAL unit within that sample's AVCC
/// buffer (one MP4 sample typically holds several NAL units).
//
// All fields are plain integers, so `Eq` and `Hash` are derivable for free
// alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Mp4NalLocation {
    /// MP4 track containing the sample.
    track_id: u32,
    /// 1-based sample id within the track (MP4 uses 1-based indexing).
    sample_id: u32,
    /// Index of the NAL unit within the sample's AVCC buffer.
    idx: usize,
}

/// A seekable source of H264 NAL units backed by an MP4 file.
pub struct Mp4Source {
    /// Parsed MP4 container, used to (re-)read individual samples on demand.
    mp4_reader: mp4::Mp4Reader<Box<dyn SeekRead + Send>>,
    /// Location of every NAL unit in the H264 video track, pre-computed once
    /// when the source is created.
    nal_locations: Vec<Mp4NalLocation>,
}

impl SeekableH264Source for Mp4Source {
    type NalLocation = Mp4NalLocation;

    /// Return the pre-computed locations of all NAL units in the source.
    fn nal_boundaries(&mut self) -> &[Self::NalLocation] {
        self.nal_locations.as_slice()
    }

    /// Read the bytes of the NAL unit at `location` by re-reading the
    /// containing MP4 sample and splitting it into its NAL units.
    fn read_nal(&mut self, location: &Self::NalLocation) -> Result<Vec<u8>> {
        let sample = self
            .mp4_reader
            .read_sample(location.track_id, location.sample_id)?;
        match sample {
            None => anyhow::bail!("sample in track disappeared"),
            Some(sample) if sample.bytes.is_empty() => anyhow::bail!("sample is empty"),
            Some(sample) => {
                let sample_nal_units = avcc_to_nalu_ebsp(sample.bytes.as_ref())?;
                Ok(sample_nal_units[location.idx].to_vec())
            }
        }
    }
}

/// Build an [`H264Source`] from an already-opened MP4 reader.
///
/// Scans the file's single H264 video track, recording the location and
/// presentation timestamp of every NAL unit, then wraps everything in an
/// [`Mp4Source`].
pub(crate) fn from_reader_with_timestamp_source(
    mut mp4_reader: mp4::Mp4Reader<Box<dyn SeekRead + Send>>,
    do_decode_h264: bool,
    timestamp_source: crate::TimestampSource,
) -> Result<H264Source<Mp4Source>> {
    let timescale = mp4_reader.timescale();

    // Locate the H264 video track, rejecting files with more than one.
    let mut video_track = None;
    for (track_id, track) in mp4_reader.tracks().iter() {
        if track.media_type()? != MediaType::H264 {
            // Ignore all tracks except H264 (audio, subtitles, ...).
            continue;
        }
        if video_track.is_some() {
            anyhow::bail!("only MP4 files with a single H264 video track are supported");
        }
        video_track = Some((track_id, track));
    }
    let (track_id, track) = match video_track {
        Some(vt) => vt,
        None => anyhow::bail!("No H264 video track found in MP4 file."),
    };
    let track_id = *track_id;

    // Grab the out-of-band parameter sets stored in the track header.
    let data_from_mp4_track = crate::h264_source::FromMp4Track {
        sequence_parameter_set: track.sequence_parameter_set()?.to_vec(),
        picture_parameter_set: track.picture_parameter_set()?.to_vec(),
    };

    // Iterate over every sample in the track. Typically (always?) one such MP4
    // sample corresponds to one frame of video (and often multiple NAL units).
    // Here we assume this 1:1 mapping between MP4 samples and video frames.
    // Both `nal_locations` and `mp4_pts` receive one entry per NAL unit.
    let mut nal_locations = Vec::new();
    let mut mp4_pts = Vec::new();
    let mut sample_id = 1; // MP4 sample ids are 1-based.
    while let Some(sample) = mp4_reader.read_sample(track_id, sample_id)? {
        if !sample.bytes.is_empty() {
            let this_pts = raw2dur(sample.start_time, timescale);
            let n_nal_units = avcc_to_nalu_ebsp(sample.bytes.as_ref())?.len();
            for idx in 0..n_nal_units {
                mp4_pts.push(this_pts);
                nal_locations.push(Mp4NalLocation {
                    track_id,
                    sample_id,
                    idx,
                });
            }
        }
        sample_id += 1;
    }

    let seekable_h264_source = Mp4Source {
        mp4_reader,
        nal_locations,
    };

    Ok(H264Source::from_seekable_h264_source_with_timestamp_source(
        seekable_h264_source,
        do_decode_h264,
        Some(mp4_pts),
        Some(data_from_mp4_track),
        timestamp_source,
    )?)
}

/// Open the MP4 file at `path` and build an [`H264Source`] from it.
///
/// # Errors
///
/// Fails if the file cannot be opened, is not a valid MP4, or does not
/// contain exactly one H264 video track.
pub fn from_path_with_timestamp_source<P: AsRef<Path>>(
    path: P,
    do_decode_h264: bool,
    timestamp_source: crate::TimestampSource,
) -> Result<H264Source<Mp4Source>> {
    let path = path.as_ref();
    let file =
        std::fs::File::open(path).with_context(|| format!("Opening {}", path.display()))?;
    let size = file.metadata()?.len();
    // Buffer reads: the MP4 parser performs many small seeks and reads.
    let buf_reader: Box<dyn SeekRead + Send> = Box::new(std::io::BufReader::new(file));
    let mp4_reader = mp4::Mp4Reader::read_header(buf_reader, size)?;

    from_reader_with_timestamp_source(mp4_reader, do_decode_h264, timestamp_source)
        .with_context(|| format!("Reading MP4 file {}", path.display()))
}

/// Parse sample from MP4 as NAL units.
///
/// In MP4 files, each sample buffer is multiple NAL units consisting of a
/// 4-byte length header and the data.
///
/// This function is not capable of parsing on non-NALU boundaries and must
/// contain complete NALUs. For well-formed MP4 files, this should be the case.
/// Parse sample from MP4 as NAL units.
///
/// In MP4 files, each sample buffer is multiple NAL units consisting of a
/// 4-byte big-endian length header and the data.
///
/// This function is not capable of parsing on non-NALU boundaries and must
/// contain complete NALUs. For well-formed MP4 files, this should be the case.
///
/// # Errors
///
/// Fails if the buffer ends in the middle of a length header, if a length
/// header claims more bytes than remain, or if a length would overflow
/// `usize`.
fn avcc_to_nalu_ebsp(mp4_sample_buffer: &[u8]) -> Result<Vec<&[u8]>> {
    let mut result = vec![];
    let mut cur_buf = mp4_sample_buffer;
    while !cur_buf.is_empty() {
        if cur_buf.len() < 4 {
            anyhow::bail!("sample buffer is too short for NAL unit header");
        }
        let header = [cur_buf[0], cur_buf[1], cur_buf[2], cur_buf[3]];
        let sz: usize = u32::from_be_bytes(header).try_into().unwrap();
        // `checked_add` guards against wrap-around on 32-bit targets, where a
        // bogus header near `u32::MAX` would overflow `sz + 4`.
        let used = sz
            .checked_add(4)
            .ok_or_else(|| anyhow::eyre!("NAL unit length overflows"))?;
        if cur_buf.len() < used {
            anyhow::bail!("AVCC buffer length: {sz}+4 but buffer {}", cur_buf.len());
        }
        result.push(&cur_buf[4..used]);
        cur_buf = &cur_buf[used..];
    }
    // No trailing-bytes check is needed: each iteration consumes exactly
    // `used` bytes and the loop exits only when `cur_buf` is empty, so the
    // NAL unit sizes always total exactly `mp4_sample_buffer.len()`.
    Ok(result)
}

/// Convert a raw MP4 timestamp (in `timescale` units per second) to a
/// [`std::time::Duration`].
fn raw2dur(raw: u64, timescale: u32) -> std::time::Duration {
    let seconds = raw as f64 / f64::from(timescale);
    std::time::Duration::from_secs_f64(seconds)
}

#[test]
fn test_raw_duration() {
    const TIMESCALE: u32 = 90_000;

    /// Inverse of `raw2dur` at the fixed test timescale.
    fn dur2raw(dur: &std::time::Duration) -> u64 {
        (dur.as_secs_f64() * TIMESCALE as f64).round() as u64
    }

    // Raw values spanning several orders of magnitude must survive a
    // raw -> Duration -> raw round trip exactly.
    for orig in [0u64, 100, 1_000_000, 1_000_000_000, 1_000_000_000_000] {
        let roundtripped = dur2raw(&raw2dur(orig, TIMESCALE));
        assert_eq!(orig, roundtripped);
    }
}