1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
//! Functions for escaping XML special characters in attribute values and character data (PCDATA).

use std::borrow::Cow;
use std::fmt::{Display, Formatter, Result};
use std::marker::PhantomData;

/// A table mapping individual bytes to the XML text that should replace them.
///
/// Implementors only supply `escape`; the remaining methods are derived from it.
pub(crate) trait Escapes {
    /// Returns the replacement text for `c`, or `None` when `c` may be written verbatim.
    fn escape(c: u8) -> Option<&'static str>;

    /// Returns `true` when `escape` provides a replacement for the byte `c`.
    fn byte_needs_escaping(c: u8) -> bool {
        let replacement = Self::escape(c);
        replacement.is_some()
    }

    /// Returns `true` when at least one byte of `s` has a replacement in `escape`.
    fn str_needs_escaping(s: &str) -> bool {
        // Equivalent to "any byte has a replacement", phrased via De Morgan.
        !s.bytes().all(|byte| Self::escape(byte).is_none())
    }
}

/// A borrowed string paired with an escape table type `E`.
///
/// The string is stored unescaped; escaping happens when the value is
/// formatted through its `Display` implementation. The `E: Escapes` bound is
/// deliberately kept off the struct definition and placed only on the impls
/// that need it (Rust API Guidelines, C-STRUCT-BOUNDS).
pub(crate) struct Escaped<'a, E> {
    // Zero-sized marker that selects the escape table; no `E` value is stored.
    _escape_phantom: PhantomData<E>,
    // The raw, not-yet-escaped input.
    to_escape: &'a str,
}

impl<'a, E: Escapes> Escaped<'a, E> {
    /// Wraps `s` so that it is escaped with the table `E` when formatted.
    ///
    /// No work is done here; the input is only borrowed.
    pub fn new(s: &'a str) -> Self {
        Self {
            to_escape: s,
            _escape_phantom: PhantomData,
        }
    }
}

impl<E: Escapes> Display for Escaped<'_, E> {
    /// Writes `to_escape`, replacing every byte for which `E::escape` returns
    /// `Some` with that replacement and passing all other bytes through unchanged.
    ///
    /// Runs of bytes that need no escaping are written as whole slices, and each
    /// byte is looked up in the table exactly once. Slicing at an escaped
    /// position relies on escaped bytes being ASCII (an ASCII byte always lies
    /// on a UTF-8 char boundary); the tables defined in this module only map
    /// ASCII byte literals.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        let input = self.to_escape;
        // Byte offset of the first input byte not yet written to `f`.
        let mut pending = 0;

        for (i, byte) in input.bytes().enumerate() {
            if let Some(replacement) = E::escape(byte) {
                debug_assert!(byte.is_ascii(), "escape tables must only map ASCII bytes");
                f.write_str(&input[pending..i])?;
                f.write_str(replacement)?;
                pending = i + 1;
            }
        }

        f.write_str(&input[pending..])
    }
}

/// Escapes `s` with the table `E`.
///
/// When no byte of `s` needs replacing, the input is handed back borrowed;
/// a new `String` is allocated only when at least one replacement happens.
fn escape_str<E: Escapes>(s: &str) -> Cow<'_, str> {
    if !E::str_needs_escaping(s) {
        return Cow::Borrowed(s);
    }
    let escaped: String = Escaped::<E>::new(s).to_string();
    Cow::Owned(escaped)
}

/// Defines a unit struct `$name` implementing `Escapes` from a list of
/// `byte => replacement` pairs (a trailing comma is accepted).
///
/// Each left-hand side is captured as an `expr` but expanded into a `match`
/// arm pattern, so in practice it should be a byte literal such as `b'<'`.
/// Any byte not listed maps to `None`, i.e. is written verbatim.
macro_rules! escapes {
    {
        $name: ident,
        $($k: expr => $v: expr),* $(,)?
    } => {
        // Zero-sized marker type; it carries no data, only the escape table.
        pub(crate) struct $name;

        impl Escapes for $name {
            fn escape(c: u8) -> Option<&'static str> {
                match c {
                    $( $k => Some($v),)*
                    // Unlisted bytes need no escaping.
                    _ => None
                }
            }
        }
    };
}

// Escape table for attribute values: markup characters and both quote styles,
// plus the two line-break characters so the value is written on a single line.
escapes! {
    AttributeEscapes,
    b'&'  => "&amp;",
    b'<'  => "&lt;",
    b'>'  => "&gt;",
    b'\'' => "&apos;",
    b'"'  => "&quot;",
    b'\r' => "&#xD;",
    b'\n' => "&#xA;",
}

// Escape table for character data (PCDATA). Quotes are not escaped, which is
// why this table must not be used for attribute values.
escapes! {
    PcDataEscapes,
    b'&' => "&amp;",
    b'<' => "&lt;",
    b'>' => "&gt;",
}

/// Escapes `s` so it can be placed inside an XML attribute value.
///
/// Markup-significant characters are replaced with entity references:
///
/// * `&` → `&amp;`
/// * `<` → `&lt;`
/// * `>` → `&gt;`
/// * `"` → `&quot;`
/// * `'` → `&apos;`
///
/// Line breaks are replaced with character references so that the attribute
/// is written on a single line:
///
/// * `\n` → `&#xA;`
/// * `\r` → `&#xD;`
///
/// The result may be used both inside attribute values and inside PCDATA.
///
/// If `s` contains none of the characters above, it is returned borrowed
/// without allocating.
#[inline]
#[must_use]
pub fn escape_str_attribute(s: &str) -> Cow<'_, str> {
    escape_str::<AttributeEscapes>(s)
}

/// Performs escaping of common XML characters inside PCDATA.
///
/// This function replaces the following markup characters with their
/// entity equivalents:
///
/// * `<` → `&lt;`
/// * `>` → `&gt;`
/// * `&` → `&amp;`
///
/// Quotes and line breaks are left untouched, so the resulting string is safe
/// to use inside PCDATA sections but NOT inside attribute values.
///
/// Does not perform allocations if the given string does not contain escapable characters.
#[inline]
#[must_use]
pub fn escape_str_pcdata(s: &str) -> Cow<'_, str> {
    escape_str::<PcDataEscapes>(s)
}

#[cfg(test)]
mod tests {
    use super::{escape_str_attribute, escape_str_pcdata};
    use std::borrow::Cow;

    #[test]
    fn test_escape_str_attribute() {
        assert_eq!(escape_str_attribute("<>'\"&\n\r"), "&lt;&gt;&apos;&quot;&amp;&#xA;&#xD;");
        assert_eq!(escape_str_attribute("no_escapes"), "no_escapes");
    }

    #[test]
    fn test_escape_str_pcdata() {
        assert_eq!(escape_str_pcdata("<>&"), "&lt;&gt;&amp;");
        assert_eq!(escape_str_pcdata("no_escapes"), "no_escapes");
    }

    // PCDATA escaping must leave quotes and line breaks alone.
    #[test]
    fn test_escape_str_pcdata_keeps_quotes_and_newlines() {
        assert_eq!(escape_str_pcdata("'\"\n\r"), "'\"\n\r");
    }

    #[test]
    fn test_escape_multibyte_code_points() {
        assert_eq!(escape_str_attribute("☃<"), "☃&lt;");
        assert_eq!(escape_str_pcdata("☃<"), "☃&lt;");
    }

    #[test]
    fn test_escape_empty_and_boundaries() {
        assert_eq!(escape_str_attribute(""), "");
        assert_eq!(escape_str_pcdata(""), "");
        assert_eq!(escape_str_pcdata("<a&"), "&lt;a&amp;");
        assert_eq!(escape_str_attribute("&&"), "&amp;&amp;");
    }

    // The public functions document that they do not allocate when nothing
    // needs escaping; pin that contract via the `Cow` variant.
    #[test]
    fn test_no_allocation_without_escapable_characters() {
        assert!(matches!(escape_str_attribute("no_escapes"), Cow::Borrowed("no_escapes")));
        assert!(matches!(escape_str_pcdata("'quoted'"), Cow::Borrowed(_)));
        assert!(matches!(escape_str_attribute("a<b"), Cow::Owned(_)));
        assert!(matches!(escape_str_pcdata("a<b"), Cow::Owned(_)));
    }
}