1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
//! Image-related utilities.

use smallvec::{smallvec, SmallVec};

use crate::datatypes::{TensorData, TensorDimension};

/// Error returned when trying to interpret a tensor as an image.
///
/// `T` is the source type the caller attempted to convert. It must implement
/// `TryInto<TensorData>`, and the error type of that conversion must implement
/// [`std::error::Error`] so it can be carried inside this error.
///
/// The user-facing message of each variant is defined by its `#[error(...)]` attribute.
#[derive(thiserror::Error, Clone, Debug)]
pub enum ImageConstructionError<T: TryInto<TensorData>>
where
    T::Error: std::error::Error,
{
    /// Could not convert source to [`TensorData`].
    ///
    /// Wraps the error type of `T`'s `TryInto<TensorData>` conversion.
    #[error("Could not convert source to TensorData: {0}")]
    TensorDataConversion(T::Error),

    /// The tensor did not have the right shape for an image (e.g. had too many dimensions).
    ///
    /// Carries a list of tensor dimensions, which the error message prints as the shape
    /// (presumably the shape of the rejected tensor; the construction site is outside this
    /// view).
    #[error("Could not create Image from TensorData with shape {0:?}")]
    BadImageShape(Vec<TensorDimension>),
}

/// Selects the contiguous run of dimension indices that describes the image content of `shape`.
///
/// Shapes with zero, one or two dimensions are returned in full. For longer shapes the
/// leading and trailing 1-sized dimensions are dropped, and the remaining span is then padded
/// back out:
///
/// * a span that covers a single dimension takes in the dimension right after it, if any;
/// * then preceding dimensions are added until the span has 3 entries when its last dimension
///   has size 3 or 4, or 2 entries otherwise (stopping early once index 0 is reached).
///
/// For instance: `[1, 640, 480, 3, 1]` would return `[1, 2, 3]`,
/// the indices of the `[640, 480, 3]` dimensions.
pub fn find_non_empty_dim_indices(shape: &[TensorDimension]) -> SmallVec<[usize; 4]> {
    // Up to two dimensions there is nothing to trim: every index is kept.
    match shape {
        [] => return smallvec![],
        [_] => return smallvec![0],
        [_, _] => return smallvec![0, 1],
        _ => {}
    }

    // Locate the outermost dimensions whose size is not 1.
    // [1, 1, 1, 640, 480, 3, 1, 1, 1]
    //           ^---------^   goal range (both ends inclusive)
    //
    // Index 0 is always valid, so it serves as the fallback when every dimension has size 1.
    let first_non_unit = shape.iter().position(|dim| dim.size != 1).unwrap_or(0);
    let last_non_unit = shape
        .iter()
        .rposition(|dim| dim.size != 1)
        .unwrap_or(first_non_unit);

    // A 1-sized dimension after the data is more likely to be intentional than one before it,
    // so a single-dimension span first extends forwards to reach a length of 2.
    // (1x1x3x1) -> 3x1 mono rather than 1x1x3 RGB
    let has_following_dim = last_non_unit + 1 < shape.len();
    let end = if last_non_unit == first_non_unit && has_following_dim {
        last_non_unit + 1
    } else {
        last_non_unit
    };

    // Only then are 1-sized dimensions in front of the data pulled back in:
    // up to 3 dimensions in total when the last one has size 3 or 4 (color images),
    // otherwise up to 2.
    // (1x1x3) -> 1x1x3 rgb rather than 1x3 mono
    let wanted_len: usize = if matches!(shape[end].size, 3 | 4) {
        3
    } else {
        2
    };
    let current_len = end - first_non_unit + 1;
    let missing = wanted_len.saturating_sub(current_len);
    // Never step back past index 0.
    let start = first_non_unit - missing.min(first_non_unit);

    (start..=end).collect()
}

/// Checks `find_non_empty_dim_indices` against a table of `(shape, expected indices)` pairs.
#[test]
fn test_find_non_empty_dim_indices() {
    // Each entry: dimension sizes of the input shape, then the indices that should be selected.
    let cases: &[(&[u64], &[usize])] = &[
        // Zero or one dimension: everything is kept, whatever the size.
        (&[], &[]),
        (&[0], &[0]),
        (&[1], &[0]),
        (&[100], &[0]),
        // Leading and trailing 1-sized dimensions are trimmed around the image dimensions.
        (&[640, 480], &[0, 1]),
        (&[640, 480, 1], &[0, 1]),
        (&[640, 480, 1, 1], &[0, 1]),
        (&[640, 480, 3], &[0, 1, 2]),
        (&[1, 640, 480], &[1, 2]),
        (&[1, 640, 480, 3, 1], &[1, 2, 3]),
        (&[1, 3, 640, 480, 1], &[1, 2, 3]),
        (&[1, 1, 640, 480], &[2, 3]),
        (&[1, 1, 640, 480, 1, 1], &[2, 3]),
        // A lone non-unit dimension: grows forwards first, backwards only if still needed.
        (&[1, 1, 3], &[0, 1, 2]),
        (&[1, 1, 3, 1], &[2, 3]),
    ];

    for &(shape, expected) in cases {
        // Build unnamed dimensions carrying just the sizes under test.
        let mut dims = Vec::with_capacity(shape.len());
        for &size in shape {
            dims.push(TensorDimension { size, name: None });
        }

        let got = find_non_empty_dim_indices(&dims);
        assert!(
            got.as_slice() == expected,
            "Input: {shape:?}, got {got:?}, expected {expected:?}"
        );
    }
}