identity_core/common/
data_url.rs1use std::fmt::Display;
5use std::str::FromStr;
6
7use serde::Serialize;
8
9use crate::common::Url;
10
/// Media type reported by [`DataUrl::media_type`] when the data URL omits one
/// (e.g. `data:,hello` or `data:;base64,aGk=`).
const DEFAULT_MIME_TYPE: &str = "text/plain";
12
/// A parsed `data:` URL, kept in its serialized form.
///
/// Only the serialized text is stored; the media type and the (still encoded)
/// data are recovered on demand by slicing it at `start_of_data`.
///
/// Comparison, ordering and hashing are derived and therefore follow the field
/// order below (the serialized text first).
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct DataUrl {
    /// The complete data URL text, e.g. `data:text/plain;base64,aGk=`.
    serialized: Box<str>,
    /// Byte offset into `serialized` where the data begins (just past the `,`).
    start_of_data: u32,
    /// Whether the `;base64` marker is present before the `,`.
    base64: bool,
}
20
21impl AsRef<str> for DataUrl {
22 fn as_ref(&self) -> &str {
23 &self.serialized
24 }
25}
26
27impl DataUrl {
28 pub const fn as_str(&self) -> &str {
30 &self.serialized
31 }
32
33 pub fn parse(input: &str) -> Result<Self, InvalidDataUrl> {
48 use nom::combinator::all_consuming;
49 use nom::Parser as _;
50
51 let (_, data_url) = all_consuming(parsers::data_url)
52 .parse(input)
53 .map_err(|_| InvalidDataUrl {})?;
54 Ok(data_url)
55 }
56
57 pub const fn is_base64(&self) -> bool {
59 self.base64
60 }
61
62 pub fn encoded_data(&self) -> &str {
64 let idx = self.start_of_data as usize;
65 &self.as_str()[idx..]
66 }
67
68 pub fn media_type(&self) -> &str {
76 let start = "data:".len();
77 let end = self.start_of_data as usize
78 - 1 - self.base64 as usize * ";base64".len(); let mime = &self.serialized[start..end];
82 if mime.is_empty() {
83 DEFAULT_MIME_TYPE
84 } else {
85 mime
86 }
87 }
88}
89
90impl Display for DataUrl {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.write_str(&self.serialized)
93 }
94}
95
96impl FromStr for DataUrl {
97 type Err = InvalidDataUrl;
98
99 fn from_str(s: &str) -> Result<Self, Self::Err> {
100 DataUrl::parse(s)
101 }
102}
103
104impl From<DataUrl> for Url {
105 fn from(data_url: DataUrl) -> Self {
106 Url::parse(data_url.as_str()).expect("DataUrl is always a valid Url")
107 }
108}
109
110impl Serialize for DataUrl {
111 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
112 where
113 S: serde::Serializer,
114 {
115 serializer.serialize_str(&self.serialized)
116 }
117}
118
119impl<'de> serde::Deserialize<'de> for DataUrl {
120 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
121 where
122 D: serde::Deserializer<'de>,
123 {
124 use serde::de::Error;
125
126 let str = <&str>::deserialize(deserializer)?;
127 DataUrl::parse(str).map_err(|_| Error::custom("invalid data URL"))
128 }
129}
130
/// Error returned when a string is not a valid data URL.
///
/// Carries no details; marked `#[non_exhaustive]` so fields can be added later.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct InvalidDataUrl {}

impl Display for InvalidDataUrl {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid data URL")
    }
}

/// Marker impl: no underlying `source` error.
impl std::error::Error for InvalidDataUrl {}
143mod parsers {
144 use nom::branch::alt;
145 use nom::bytes::complete::tag;
146 use nom::bytes::complete::take_while1;
147 use nom::bytes::complete::take_while_m_n;
148 #[cfg(test)]
149 use nom::combinator::all_consuming;
150 use nom::combinator::opt;
151 use nom::combinator::recognize;
152 use nom::multi::many1_count;
153 use nom::sequence::preceded;
154 use nom::sequence::separated_pair;
155 use nom::IResult;
156 use nom::Parser;
157
158 use super::DataUrl;
159
160 pub(super) fn data_url(input: &str) -> IResult<&str, DataUrl> {
161 let (rem, (_type, base64, data)) = preceded(
162 tag("data:"),
163 (
164 opt(mediatype),
165 opt(tag(";base64")).map(|opt| opt.is_some()),
166 preceded(tag(","), uri_char1),
167 ),
168 )
169 .parse(input)?;
170
171 let consumed = input.len() - rem.len();
172 let serialized = input[..consumed].to_owned().into_boxed_str();
173 let start_of_data = (consumed - data.len()) as u32;
174
175 Ok((
176 rem,
177 DataUrl {
178 serialized,
179 start_of_data,
180 base64,
181 },
182 ))
183 }
184
185 fn mediatype(input: &str) -> IResult<&str, &str> {
186 let type_ = separated_pair(media_char1, tag("/"), media_char1);
187 let parameters = many1_count(preceded(tag(";"), separated_pair(media_char1, tag("="), media_char1)));
188
189 recognize((type_, opt(parameters))).parse(input)
190 }
191
192 fn uri_char1(input: &str) -> IResult<&str, &str> {
193 let reserved = take_while1(|c: char| ";/?:@&=+$,".contains(c));
194 let unreserved = take_while1(|c: char| "-_.!~*'(|)".contains(c) || c.is_ascii_alphanumeric());
195 let escaped = recognize(percent_escaped);
196
197 recognize(many1_count(alt((reserved, unreserved, escaped)))).parse(input)
198 }
199
200 fn media_char1(input: &str) -> IResult<&str, &str> {
201 take_while1(|c: char| c.is_ascii_alphanumeric() || "-_.+".contains(c))(input)
202 }
203
204 fn percent_escaped(input: &str) -> IResult<&str, u8> {
205 preceded(tag("%"), take_while_m_n(2, 2, |c: char| c.is_ascii_hexdigit()))
206 .map_res(|hex_byte| u8::from_str_radix(hex_byte, 16))
207 .parse(input)
208 }
209
210 #[cfg(test)]
211 #[test]
212 fn mediatype_parser() {
213 all_consuming(mediatype).parse("text/plain").unwrap();
214 all_consuming(mediatype).parse("application/vc+jwt").unwrap();
215 all_consuming(mediatype).parse("video/mp4").unwrap();
216 all_consuming(mediatype).parse("text/plain;charset=us-ascii").unwrap();
217 }
218
219 #[cfg(test)]
220 #[test]
221 fn data_url_parser() {
222 all_consuming(data_url).parse("data:text/plain,hello").unwrap();
223 all_consuming(data_url)
224 .parse("data:text/plain;charset=us-ascii,hello%20world")
225 .unwrap();
226 all_consuming(data_url).parse("data:,hello%20world").unwrap();
227 all_consuming(data_url).parse("data:application/vc+jwt,ey").unwrap();
228 let (_, data_url) = all_consuming(data_url)
229 .parse("data:application/vc+jwt;base64,ey")
230 .unwrap();
231 assert!(data_url.is_base64());
232 }
233}