Skip to content

Commit efadaf3

Browse files
Complete rewrite of StanzaFilter, more tests, supports CDATA and more
1 parent 3792d22 commit efadaf3

File tree

1 file changed

+146
-41
lines changed

1 file changed

+146
-41
lines changed

src/stanzafilter.rs

Lines changed: 146 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,28 @@
11
use anyhow::{bail, Result};
22

3+
use crate::stanzafilter::StanzaState::*;
34
use crate::to_str;
45

6+
#[derive(Debug)]
7+
enum StanzaState {
8+
OutsideStanza,
9+
StanzaFirstChar,
10+
InsideTagFirstChar,
11+
InsideTag,
12+
BetweenTags,
13+
ExclamationTag(usize),
14+
InsideCDATA,
15+
QuestionTag(usize),
16+
InsideXmlTag,
17+
EndStream,
18+
}
19+
520
pub struct StanzaFilter {
621
buf_size: usize,
722
pub buf: Vec<u8>,
823
cnt: usize,
924
tag_cnt: usize,
10-
last_char_was_lt: bool,
11-
last_char_was_backslash: bool,
25+
state: StanzaState,
1226
}
1327

1428
impl StanzaFilter {
@@ -18,8 +32,7 @@ impl StanzaFilter {
1832
buf: vec![0u8; buf_size],
1933
cnt: 0,
2034
tag_cnt: 0,
21-
last_char_was_lt: false,
22-
last_char_was_backslash: false,
35+
state: OutsideStanza,
2336
}
2437
}
2538

@@ -37,49 +50,122 @@ impl StanzaFilter {
3750
}
3851

3952
pub fn process_next_byte_idx(&mut self) -> Result<Option<usize>> {
40-
//println!("n: {}", n);
4153
let b = self.buf[self.cnt];
42-
if b == b'<' {
43-
self.tag_cnt += 1;
44-
self.last_char_was_lt = true;
45-
} else {
46-
if b == b'/' {
47-
// if last_char_was_lt but tag_cnt < 2, should only be </stream:stream>
48-
if self.last_char_was_lt && self.tag_cnt >= 2 {
49-
// non-self-closing tag
50-
self.tag_cnt -= 2;
54+
//print!("b: '{}', cnt: {}, tag_cnt: {}, state: {:?}; ", b as char, self.cnt, self.tag_cnt, self.state);
55+
match self.state {
56+
OutsideStanza => {
57+
if b == b'<' {
58+
self.tag_cnt += 1;
59+
self.state = StanzaFirstChar;
60+
} else {
61+
// outside of stanzas, let's ignore all characters except <
62+
// prosody does this, and since things do whitespace pings, it's good
63+
return Ok(None);
64+
}
65+
}
66+
BetweenTags => {
67+
if b == b'<' {
68+
self.tag_cnt += 1;
69+
self.state = InsideTagFirstChar;
5170
}
52-
self.last_char_was_backslash = true;
53-
} else {
71+
}
72+
StanzaFirstChar => match b {
73+
b'/' => self.state = EndStream,
74+
b'!' => bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)])),
75+
b'?' => self.state = QuestionTag(self.cnt + 4), // 4 is length of b"xml "
76+
_ => self.state = InsideTag,
77+
},
78+
InsideTagFirstChar => match b {
79+
b'/' => self.tag_cnt -= 2,
80+
b'!' => self.state = ExclamationTag(self.cnt + 7), // 7 is length of b"[CDATA["
81+
b'?' => bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)])),
82+
_ => self.state = InsideTag,
83+
},
84+
InsideTag => {
5485
if b == b'>' {
55-
if self.last_char_was_backslash {
86+
if self.buf[self.cnt - 1] == b'/' {
87+
// state can't be InsideTag unless we are on at least the second character, so can't go out of range
5688
// self-closing tag
5789
self.tag_cnt -= 1;
5890
}
59-
// now special case some tags we want to send stand-alone:
60-
if self.tag_cnt == 1 && self.cnt >= 15 && (b"<?xml" == &self.buf[0..5] || b"<stream:stream" == &self.buf[0..14] || b"</stream:stream" == &self.buf[0..15]) {
61-
self.tag_cnt = 0; // to fall through to next logic
62-
}
6391
if self.tag_cnt == 0 {
64-
//let ret = Ok(Some(&self.buf[0..(self.cnt + 1)]));
65-
let ret = Ok(Some(self.cnt + 1));
66-
self.cnt = 0;
67-
self.last_char_was_backslash = false;
68-
self.last_char_was_lt = false;
69-
return ret;
92+
return self.stanza_end();
93+
}
94+
// now special case <stream:stream ...> which we want to send stand-alone:
95+
if self.tag_cnt == 1 && self.buf.len() >= 15 && b"<stream:stream " == &self.buf[0..15] {
96+
return self.stanza_end();
97+
}
98+
self.state = BetweenTags;
99+
}
100+
}
101+
QuestionTag(idx) => {
102+
if idx == self.cnt {
103+
if self.last_equals(b"xml ")? {
104+
self.state = InsideXmlTag;
105+
} else {
106+
bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)]));
107+
}
108+
}
109+
}
110+
InsideXmlTag => {
111+
if b == b'>' {
112+
return self.stanza_end();
113+
}
114+
}
115+
ExclamationTag(idx) => {
116+
if idx == self.cnt {
117+
if self.last_equals(b"[CDATA[")? {
118+
self.state = InsideCDATA;
119+
self.tag_cnt -= 1; // cdata not a tag
120+
} else {
121+
bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)]));
122+
}
123+
}
124+
}
125+
InsideCDATA => {
126+
if b == b'>' && self.last_equals(b"]]>")? {
127+
self.state = BetweenTags;
128+
}
129+
}
130+
EndStream => {
131+
if b == b'>' {
132+
if self.last_equals(b"</stream:stream>")? {
133+
return self.stanza_end();
134+
} else {
135+
bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)]));
70136
}
71137
}
72-
self.last_char_was_backslash = false;
73138
}
74-
self.last_char_was_lt = false;
75139
}
76-
//println!("b: '{}', cnt: {}, tag_cnt: {}, self.buf.len(): {}", b as char, self.cnt, self.tag_cnt, self.buf.len());
140+
//println!("cnt: {}, tag_cnt: {}, state: {:?}", self.cnt, self.tag_cnt, self.state);
77141
self.cnt += 1;
78142
if self.cnt == self.buf_size {
79143
bail!("stanza too big: {}", to_str(&self.buf));
80144
}
81145
Ok(None)
82146
}
147+
148+
fn stanza_end(&mut self) -> Result<Option<usize>> {
149+
let ret = Ok(Some(self.cnt + 1));
150+
self.tag_cnt = 0;
151+
self.cnt = 0;
152+
self.state = OutsideStanza;
153+
//println!("cnt: {}, tag_cnt: {}, state: {:?}", self.cnt, self.tag_cnt, self.state);
154+
return ret;
155+
}
156+
157+
fn last_equals(&self, needle: &[u8]) -> Result<bool> {
158+
Ok(needle == self.last_num_bytes(needle.len())?)
159+
}
160+
161+
fn last_num_bytes(&self, num: usize) -> Result<&[u8]> {
162+
let num = num - 1;
163+
if num <= self.cnt {
164+
Ok(&self.buf[(self.cnt - num)..(self.cnt + 1)])
165+
} else {
166+
bail!("expected {} bytes only have {} bytes", num, (self.cnt + 1))
167+
}
168+
}
83169
}
84170

85171
// this would be better as an async trait, but that doesn't work yet...
@@ -104,27 +190,46 @@ impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
104190
#[cfg(test)]
105191
mod tests {
106192
use crate::stanzafilter::*;
107-
use std::borrow::Cow;
108193
use std::io::Cursor;
109194

110195
impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
111-
async fn next_str<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Cow<'_, str> {
112-
to_str(self.next(filter).await.expect("was Err").expect("was None"))
196+
async fn to_vec<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result<Vec<String>> {
197+
let mut ret = Vec::new();
198+
while let Some(stanza) = self.next(filter).await? {
199+
ret.push(to_str(stanza).to_string());
200+
}
201+
return Ok(ret);
113202
}
114203
}
115204

116205
#[tokio::test]
117206
async fn process_next_byte() -> std::result::Result<(), anyhow::Error> {
118207
let mut filter = StanzaFilter::new(262_144);
119208

120-
let xml_stream = Cursor::new(br###"<a/><b>inside b before c<c>inside c</c></b><d></d>"###);
121-
122-
let mut stanza_reader = StanzaReader(xml_stream);
123-
124-
assert_eq!(stanza_reader.next_str(&mut filter).await, "<a/>");
125-
assert_eq!(stanza_reader.next_str(&mut filter).await, "<b>inside b before c<c>inside c</c></b>");
126-
assert_eq!(stanza_reader.next_str(&mut filter).await, "<d></d>");
127-
assert_eq!(stanza_reader.next(&mut filter).await?, None);
209+
assert_eq!(
210+
StanzaReader(Cursor::new(
211+
br###"
212+
<?xml version='1.0'?>
213+
<stream:stream xmlns='jabber:server' xmlns:stream='http://etherx.jabber.org/streams' xmlns:db='jabber:server:dialback' version='1.0' to='example.org' from='example.com' xml:lang='en'>
214+
<a/><b>inside b before c<c>inside c</c></b></stream:stream>
215+
<q>bla<![CDATA[<this>is</not><xml/>]]>bloo</q>
216+
<d></d><e><![CDATA[what]>]]]]></e></stream:stream>
217+
"###,
218+
))
219+
.to_vec(&mut filter)
220+
.await?,
221+
vec![
222+
"<?xml version='1.0'?>",
223+
"<stream:stream xmlns='jabber:server' xmlns:stream='http://etherx.jabber.org/streams' xmlns:db='jabber:server:dialback' version='1.0' to='example.org' from='example.com' xml:lang='en'>",
224+
"<a/>",
225+
"<b>inside b before c<c>inside c</c></b>",
226+
"</stream:stream>",
227+
"<q>bla<![CDATA[<this>is</not><xml/>]]>bloo</q>",
228+
"<d></d>",
229+
"<e><![CDATA[what]>]]]]></e>",
230+
"</stream:stream>",
231+
]
232+
);
128233

129234
Ok(())
130235
}

0 commit comments

Comments
 (0)