1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
//! Converts YUV / RGB images to NAL packets.

use crate::error::NativeErrorExt;
use crate::formats::YUVSource;
use crate::{Error, Timestamp};
use openh264_sys2::{
    videoFormatI420, EVideoFormatType, ISVCEncoder, ISVCEncoderVtbl, SEncParamBase, SEncParamExt, SFrameBSInfo, SLayerBSInfo,
    SSourcePicture, WelsCreateSVCEncoder, WelsDestroySVCEncoder, ENCODER_OPTION, ENCODER_OPTION_DATAFORMAT,
    ENCODER_OPTION_TRACE_LEVEL, RC_MODES, VIDEO_CODING_LAYER, WELS_LOG_DETAIL, WELS_LOG_QUIET,
};
use std::os::raw::{c_int, c_uchar, c_void};
use std::ptr::{addr_of_mut, null, null_mut};

/// Convenience wrapper with guaranteed function pointers for easy access.
///
/// This struct automatically handles `WelsCreateSVCEncoder` and `WelsDestroySVCEncoder`.
///
/// Every function pointer is copied out of the encoder's vtable once, when the wrapper is
/// created, so the optional vtable entries do not have to be re-checked on each call.
#[rustfmt::skip]
#[allow(non_snake_case)]
#[derive(Debug)]
pub struct EncoderRawAPI {
    /// Native encoder instance obtained from `WelsCreateSVCEncoder`; passed as the first
    /// argument of every function pointer below and released with `WelsDestroySVCEncoder`.
    encoder_ptr: *mut *const ISVCEncoderVtbl,
    /// Vtable entry `Initialize`.
    initialize: unsafe extern "C" fn(arg1: *mut ISVCEncoder, pParam: *const SEncParamBase) -> c_int,
    /// Vtable entry `InitializeExt`.
    initialize_ext: unsafe extern "C" fn(arg1: *mut ISVCEncoder, pParam: *const SEncParamExt) -> c_int,
    /// Vtable entry `GetDefaultParams`.
    get_default_params: unsafe extern "C" fn(arg1: *mut ISVCEncoder, pParam: *mut SEncParamExt) -> c_int,
    /// Vtable entry `Uninitialize`.
    uninitialize: unsafe extern "C" fn(arg1: *mut ISVCEncoder) -> c_int,
    /// Vtable entry `EncodeFrame`.
    encode_frame: unsafe extern "C" fn(arg1: *mut ISVCEncoder, kpSrcPic: *const SSourcePicture, pBsInfo: *mut SFrameBSInfo) -> c_int,
    /// Vtable entry `EncodeParameterSets`.
    encode_parameter_sets: unsafe extern "C" fn(arg1: *mut ISVCEncoder, pBsInfo: *mut SFrameBSInfo) -> c_int,
    /// Vtable entry `ForceIntraFrame`.
    force_intra_frame: unsafe extern "C" fn(arg1: *mut ISVCEncoder, bIDR: bool) -> c_int,
    /// Vtable entry `SetOption`.
    set_option: unsafe extern "C" fn(arg1: *mut ISVCEncoder, eOptionId: ENCODER_OPTION, pOption: *mut c_void) -> c_int,
    /// Vtable entry `GetOption`.
    get_option: unsafe extern "C" fn(arg1: *mut ISVCEncoder, eOptionId: ENCODER_OPTION, pOption: *mut c_void) -> c_int,
}

#[rustfmt::skip]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::missing_safety_doc)]
#[allow(non_snake_case)]
#[allow(unused)]
impl EncoderRawAPI {
    fn new() -> Result<Self, Error> {
        unsafe {
            let mut encoder_ptr = null::<ISVCEncoderVtbl>() as *mut *const ISVCEncoderVtbl;

            WelsCreateSVCEncoder(&mut encoder_ptr as *mut *mut *const ISVCEncoderVtbl).ok()?;

            let e = || {
                Error::msg("VTable missing function.")
            };

            Ok(Self {
                encoder_ptr,
                initialize: (*(*encoder_ptr)).Initialize.ok_or_else(e)?,
                initialize_ext: (*(*encoder_ptr)).InitializeExt.ok_or_else(e)?,
                get_default_params: (*(*encoder_ptr)).GetDefaultParams.ok_or_else(e)?,
                uninitialize: (*(*encoder_ptr)).Uninitialize.ok_or_else(e)?,
                encode_frame: (*(*encoder_ptr)).EncodeFrame.ok_or_else(e)?,
                encode_parameter_sets: (*(*encoder_ptr)).EncodeParameterSets.ok_or_else(e)?,
                force_intra_frame: (*(*encoder_ptr)).ForceIntraFrame.ok_or_else(e)?,
                set_option: (*(*encoder_ptr)).SetOption.ok_or_else(e)?,
                get_option: (*(*encoder_ptr)).GetOption.ok_or_else(e)?,
            })
        }
    }

    // Exposing these will probably do more harm than good.
    unsafe fn uninitialize(&self) -> c_int { (self.uninitialize)(self.encoder_ptr) }
    unsafe fn initialize(&self, pParam: *const SEncParamBase) -> c_int { (self.initialize)(self.encoder_ptr, pParam) }
    unsafe fn initialize_ext(&self, pParam: *const SEncParamExt) -> c_int { (self.initialize_ext)(self.encoder_ptr, pParam) }

    pub unsafe fn get_default_params(&self, pParam: *mut SEncParamExt) -> c_int { (self.get_default_params)(self.encoder_ptr, pParam) }
    pub unsafe fn encode_frame(&self, kpSrcPic: *const SSourcePicture, pBsInfo: *mut SFrameBSInfo) -> c_int { (self.encode_frame)(self.encoder_ptr, kpSrcPic, pBsInfo) }
    pub unsafe fn encode_parameter_sets(&self, pBsInfo: *mut SFrameBSInfo) -> c_int { (self.encode_parameter_sets)(self.encoder_ptr, pBsInfo) }
    pub unsafe fn force_intra_frame(&self, bIDR: bool) -> c_int { (self.force_intra_frame)(self.encoder_ptr, bIDR) }
    pub unsafe fn set_option(&self, eOptionId: ENCODER_OPTION, pOption: *mut c_void) -> c_int { (self.set_option)(self.encoder_ptr, eOptionId, pOption) }
    pub unsafe fn get_option(&self, eOptionId: ENCODER_OPTION, pOption: *mut c_void) -> c_int { (self.get_option)(self.encoder_ptr, eOptionId, pOption) }
}

impl Drop for EncoderRawAPI {
    /// Releases the native encoder instance created in `EncoderRawAPI::new`.
    fn drop(&mut self) {
        let encoder = self.encoder_ptr;
        // SAFETY: a value of this type only exists after `WelsCreateSVCEncoder` succeeded, and the
        // type is not `Clone`, so the instance is destroyed exactly once.
        unsafe {
            WelsDestroySVCEncoder(encoder);
        }
    }
}

// SAFETY: `EncoderRawAPI` holds only the encoder instance pointer and plain function pointers.
// Every method that calls into the instance through `&self` is an `unsafe fn`, so callers are
// responsible for not making conflicting concurrent calls.
// NOTE(review): these impls assume OpenH264 encoder instances have no thread affinity and that
// sharing `&EncoderRawAPI` across threads is acceptable given the `unsafe fn` contract — confirm.
unsafe impl Send for EncoderRawAPI {}
unsafe impl Sync for EncoderRawAPI {}

/// Specifies the mode used by the encoder to control the rate.
///
/// Each variant maps to one of OpenH264's `RC_*_MODE` constants. The default is
/// [`RateControlMode::Quality`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RateControlMode {
    /// Quality mode.
    Quality,
    /// Bitrate mode.
    Bitrate,
    /// No bitrate control; only the buffer status is used to adjust the video quality.
    Bufferbased,
    /// Rate control based on timestamps.
    Timestamp,
    /// Experimental bitrate mode; OpenH264 marks it as in-progress and to be deleted after
    /// algorithm tuning.
    BitrateModePostSkip,
    /// Rate control off mode.
    Off,
}

impl Default for RateControlMode {
    fn default() -> Self {
        Self::Quality
    }
}

impl RateControlMode {
    fn to_c(self) -> RC_MODES {
        match self {
            RateControlMode::Quality => openh264_sys2::RC_QUALITY_MODE,
            RateControlMode::Bitrate => openh264_sys2::RC_BITRATE_MODE,
            RateControlMode::Bufferbased => openh264_sys2::RC_BUFFERBASED_MODE,
            RateControlMode::Timestamp => openh264_sys2::RC_TIMESTAMP_MODE,
            RateControlMode::BitrateModePostSkip => openh264_sys2::RC_BITRATE_MODE_POST_SKIP,
            RateControlMode::Off => openh264_sys2::RC_OFF_MODE,
        }
    }
}

/// Configuration for the [`Encoder`].
///
/// Setting missing? Please file a PR!
#[derive(Default, Copy, Clone, Debug)]
pub struct EncoderConfig {
    width: u32,
    height: u32,
    enable_skip_frame: bool,
    target_bitrate: u32,
    enable_denoise: bool,
    debug: i32,
    data_format: EVideoFormatType,
    max_frame_rate: f32,
    rate_control_mode: RateControlMode,
}

impl EncoderConfig {
    /// Creates a new default encoder config for frames of `width` x `height` pixels.
    ///
    /// Defaults: frame skipping enabled, 120 000 bps target bitrate, denoising disabled,
    /// debug level `0`, I420 input, maximum frame rate `0.0` and [`RateControlMode::Quality`].
    #[must_use]
    pub fn new(width: u32, height: u32) -> Self {
        Self {
            width,
            height,
            enable_skip_frame: true,
            target_bitrate: 120_000,
            enable_denoise: false,
            debug: 0,
            data_format: videoFormatI420,
            max_frame_rate: 0.0,
            rate_control_mode: Default::default(),
        }
    }

    /// Sets the requested bit rate in bits per second.
    #[must_use]
    pub fn set_bitrate_bps(mut self, bps: u32) -> Self {
        self.target_bitrate = bps;
        self
    }

    /// Enables (`WELS_LOG_DETAIL`) or silences (`WELS_LOG_QUIET`) logging inside OpenH264.
    #[must_use]
    pub fn debug(mut self, value: bool) -> Self {
        self.debug = if value { WELS_LOG_DETAIL } else { WELS_LOG_QUIET };
        self
    }

    /// Set whether frames can be skipped to meet desired rate control target.
    #[must_use]
    pub fn enable_skip_frame(mut self, value: bool) -> Self {
        self.enable_skip_frame = value;
        self
    }

    /// Sets whether denoising is enabled (disabled by default).
    #[must_use]
    pub fn enable_denoise(mut self, value: bool) -> Self {
        self.enable_denoise = value;
        self
    }

    /// Sets the requested maximum frame rate in Hz.
    #[must_use]
    pub fn max_frame_rate(mut self, value: f32) -> Self {
        self.max_frame_rate = value;
        self
    }

    /// Sets the requested rate control mode.
    #[must_use]
    pub fn rate_control_mode(mut self, value: RateControlMode) -> Self {
        self.rate_control_mode = value;
        self
    }
}

/// An [OpenH264](https://github.com/cisco/openh264) encoder.
pub struct Encoder {
    /// Parameters the native encoder was initialized with; `encode_at` checks the source
    /// picture dimensions against these.
    params: SEncParamExt,
    /// Owned native encoder instance and its vtable function pointers.
    raw_api: EncoderRawAPI,
    /// Output structure filled by every `encode_frame` call; returned bitstreams borrow it.
    bit_stream_info: SFrameBSInfo,
}

// SAFETY: all methods in this file that touch the native encoder (`encode`, `encode_at`,
// `raw_api`) take `&mut self`, so a shared `&Encoder` grants no access to it here.
// NOTE(review): `Send` assumes OpenH264 encoder instances have no thread affinity — confirm.
unsafe impl Send for Encoder {}
unsafe impl Sync for Encoder {}

impl Encoder {
    /// Create an encoder with the provided configuration.
    pub fn with_config(mut config: EncoderConfig) -> Result<Self, Error> {
        let raw_api = EncoderRawAPI::new()?;
        let mut params = SEncParamExt::default();

        #[rustfmt::skip]
        unsafe {
            raw_api.get_default_params(&mut params).ok()?;
            params.iPicWidth = config.width as c_int;
            params.iPicHeight = config.height as c_int;
            params.iRCMode = config.rate_control_mode.to_c();
            params.bEnableFrameSkip = config.enable_skip_frame;
            params.iTargetBitrate = config.target_bitrate as c_int;
            params.bEnableDenoise = config.enable_denoise;
            params.fMaxFrameRate = config.max_frame_rate;
            raw_api.initialize_ext(&params).ok()?;

            raw_api.set_option(ENCODER_OPTION_TRACE_LEVEL, addr_of_mut!(config.debug).cast()).ok()?;
            raw_api.set_option(ENCODER_OPTION_DATAFORMAT, addr_of_mut!(config.data_format).cast()).ok()?;
        };

        Ok(Self {
            params,
            raw_api,
            bit_stream_info: Default::default(),
        })
    }

    /// Encodes a YUV source and returns the encoded bitstream.
    ///
    /// The returned bitstream consists of one or more NAL units or packets. The first packets contain
    /// initialization information. Subsequent packages then contain, amongst others, keyframes
    /// ("I frames") or delta frames. The interval at which they are produced depends on the encoder settings.
    ///
    /// # Panics
    ///
    /// Panics if the source image dimension don't match the configured format.
    pub fn encode<T: YUVSource>(&mut self, yuv_source: &T) -> Result<EncodedBitStream<'_>, Error> {
        self.encode_at(yuv_source, Timestamp::ZERO)
    }

    /// Encodes a YUV source and returns the encoded bitstream.
    ///
    /// The returned bitstream consists of one or more NAL units or packets. The first packets contain
    /// initialization information. Subsequent packages then contain, amongst others, keyframes
    /// ("I frames") or delta frames. The interval at which they are produced depends on the encoder settings.
    ///
    /// # Panics
    ///
    /// Panics if the source image dimension don't match the configured format.
    ///
    /// Panics if the provided timestamp as milliseconds is out of range of i64.
    pub fn encode_at<T: YUVSource>(&mut self, yuv_source: &T, timestamp: Timestamp) -> Result<EncodedBitStream<'_>, Error> {
        assert_eq!(yuv_source.width(), self.params.iPicWidth);
        assert_eq!(yuv_source.height(), self.params.iPicHeight);

        // Converting *const u8 to *mut u8 should be fine because the encoder _should_
        // only read these arrays (TODO: needs verification).
        let source = SSourcePicture {
            iColorFormat: videoFormatI420,
            iStride: [yuv_source.y_stride(), yuv_source.u_stride(), yuv_source.v_stride(), 0],
            pData: [
                yuv_source.y().as_ptr() as *mut c_uchar,
                yuv_source.u().as_ptr() as *mut c_uchar,
                yuv_source.v().as_ptr() as *mut c_uchar,
                null_mut(),
            ],
            iPicWidth: self.params.iPicWidth,
            iPicHeight: self.params.iPicHeight,
            uiTimeStamp: timestamp.as_native(),
        };

        unsafe {
            self.raw_api.encode_frame(&source, &mut self.bit_stream_info).ok()?;

            Ok(EncodedBitStream {
                bit_stream_info: &self.bit_stream_info,
            })
        }
    }

    /// Obtain the raw API for advanced use cases.
    ///
    /// When resorting to this call, please consider filing an issue / PR.
    ///
    /// # Safety
    ///
    /// You must not set parameters the encoder relies on, we recommend checking the source.
    pub unsafe fn raw_api(&mut self) -> &mut EncoderRawAPI {
        &mut self.raw_api
    }
}

impl Drop for Encoder {
    /// Uninitializes the native encoder; `raw_api` destroys the instance afterwards when it is
    /// dropped as a field.
    fn drop(&mut self) {
        // SAFETY: an `Encoder` only exists after `with_config` initialized `raw_api`, and fields
        // are dropped only after this method returns, so the instance is still alive here.
        // The status code is deliberately discarded: `drop` has no way to report it.
        let _ = unsafe { self.raw_api.uninitialize() };
    }
}

/// Bitstream output resulting from an [encode()](Encoder::encode) operation.
///
/// It borrows the encoder's output buffer, so the encoder stays mutably borrowed (and cannot
/// encode again) while this value is alive.
pub struct EncodedBitStream<'a> {
    /// Holds the bitstream info just encoded.
    bit_stream_info: &'a SFrameBSInfo,
}

impl<'a> EncodedBitStream<'a> {
    /// Raw bitstream info returned by the encoder.
    pub fn raw_info(&self) -> &'a SFrameBSInfo {
        self.bit_stream_info
    }

    /// Frame type of the encoded packet.
    pub fn frame_type(&self) -> FrameType {
        FrameType::from_c_int(self.bit_stream_info.eFrameType)
    }

    /// Number of layers in the encoded packet.
    pub fn num_layers(&self) -> usize {
        self.bit_stream_info.iLayerNum as usize
    }

    /// Returns ith layer of this bitstream.
    pub fn layer(&self, i: usize) -> Option<Layer<'a>> {
        if i < self.num_layers() {
            Some(Layer {
                layer_info: &self.bit_stream_info.sLayerInfo[i],
            })
        } else {
            None
        }
    }

    /// Writes the current bitstream into the given Vec.
    pub fn write_vec(&self, dst: &mut Vec<u8>) {
        for l in 0..self.num_layers() {
            let layer = self.layer(l).unwrap();

            for n in 0..layer.nal_count() {
                let nal = layer.nal_unit(n).unwrap();

                dst.extend_from_slice(nal)
            }
        }
    }

    /// Writes the current bitstream into the given Writer.
    pub fn write<T: std::io::Write>(&self, writer: &mut T) -> Result<(), Error> {
        for l in 0..self.num_layers() {
            let layer = self.layer(l).unwrap();

            for n in 0..layer.nal_count() {
                let nal = layer.nal_unit(n).unwrap();

                match writer.write(nal) {
                    Ok(num) if num < nal.len() => {
                        return Err(Error::msg(&format!("only wrote {} out of {} bytes", num, nal.len())));
                    }
                    Err(e) => {
                        return Err(Error::msg(&format!("failed to write: {}", e)));
                    }
                    _ => {}
                };
            }
        }
        Ok(())
    }

    /// Convenience method returning a Vec containing the encoded bitstream.
    pub fn to_vec(&self) -> Vec<u8> {
        let mut rval = Vec::new();
        self.write_vec(&mut rval);
        rval
    }
}

/// An encoded layer of a bitstream, holding a sequence of NAL (Network Abstraction Layer) units.
#[derive(Debug)]
pub struct Layer<'a> {
    /// Native layer info, borrowed from the encoder's output for `'a`.
    layer_info: &'a SLayerBSInfo,
}

impl<'a> Layer<'a> {
    /// Raw layer info contained in a bitstream.
    pub fn raw_info(&self) -> &'a SLayerBSInfo {
        self.layer_info
    }

    /// NAL count of this layer.
    ///
    /// A negative count reported by the native side is treated as zero.
    pub fn nal_count(&self) -> usize {
        self.layer_info.iNalCount.max(0) as usize
    }

    /// Returns NAL unit data for the `i`th element, or `None` if `i >= nal_count()`.
    ///
    /// The slice borrows the encoder's output for `'a`, so it may outlive this `Layer` value.
    pub fn nal_unit(&self, i: usize) -> Option<&'a [u8]> {
        let count = self.nal_count();
        if i >= count {
            return None;
        }

        // SAFETY: `pNalLengthInByte` is a C array holding `iNalCount` NAL unit sizes, and the NAL
        // payloads are stored back to back in `pBsBuf`; both live in the encoder output borrowed
        // for `'a`.
        let lengths = unsafe { std::slice::from_raw_parts(self.layer_info.pNalLengthInByte, count) };

        // Skip over all NALs before the requested one.
        let offset: usize = lengths[..i].iter().map(|&len| len.max(0) as usize).sum();
        let size = lengths[i].max(0) as usize;

        if size == 0 {
            // `from_raw_parts` requires a non-null pointer even for an empty slice.
            return Some(&[][..]);
        }

        // SAFETY: see above; `offset..offset + size` is the `i`th NAL within `pBsBuf`.
        Some(unsafe { std::slice::from_raw_parts(self.layer_info.pBsBuf.add(offset), size) })
    }

    /// If this is a video layer or not.
    pub fn is_video(&self) -> bool {
        self.layer_info.uiLayerType == VIDEO_CODING_LAYER as c_uchar
    }
}

/// Frame type returned by the encoder.
///
/// The variant documentation was adapted from the OpenH264 project.
///
/// Declaration order is significant: `Ord`/`PartialOrd` are derived, so comparisons follow
/// the order of the variants below.
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
pub enum FrameType {
    /// Encoder not ready or parameters are invalid.
    Invalid,
    /// IDR frame in H.264.
    IDR,
    /// I frame type.
    I,
    /// P frame type.
    P,
    /// The encoder kernel skipped this frame.
    Skip,
    /// A frame where I and P slices are mixed; not supported yet.
    IPMixed,
}

impl FrameType {
    /// Maps the native frame type code reported by OpenH264 onto [`FrameType`]; any code not
    /// listed below becomes [`FrameType::Invalid`].
    fn from_c_int(native: std::os::raw::c_int) -> Self {
        use openh264_sys2 as sys;

        let known = [
            (sys::videoFrameTypeIDR, Self::IDR),
            (sys::videoFrameTypeI, Self::I),
            (sys::videoFrameTypeP, Self::P),
            (sys::videoFrameTypeSkip, Self::Skip),
            (sys::videoFrameTypeIPMixed, Self::IPMixed),
        ];

        known
            .iter()
            .find(|(code, _)| *code == native)
            .map_or(Self::Invalid, |&(_, frame_type)| frame_type)
    }
}