Skip to content

Commit 26abf92

Browse files
committed
Fix bug in ValidateInput
When reading from a stream reader that consumes input as it reads, validation would fail because it would attempt to read the stream twice.
1 parent fc65111 commit 26abf92

File tree

2 files changed

+62
-15
lines changed

2 files changed

+62
-15
lines changed

etree.go

+19-3
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,14 @@ func (d *Document) SetRoot(e *Element) {
358358
// returns the number of bytes read and any error encountered.
359359
func (d *Document) ReadFrom(r io.Reader) (n int64, err error) {
360360
if d.ReadSettings.ValidateInput {
361-
if err := validateXML(r, d.ReadSettings); err != nil {
361+
b, err := io.ReadAll(r)
362+
if err != nil {
362363
return 0, err
363364
}
365+
if err := validateXML(bytes.NewReader(b), d.ReadSettings); err != nil {
366+
return 0, err
367+
}
368+
r = bytes.NewReader(b)
364369
}
365370
return d.Element.readFrom(r, d.ReadSettings)
366371
}
@@ -373,19 +378,30 @@ func (d *Document) ReadFromFile(filepath string) error {
373378
return err
374379
}
375380
defer f.Close()
381+
376382
_, err = d.ReadFrom(f)
377383
return err
378384
}
379385

380386
// ReadFromBytes reads XML from the byte slice 'b' into the this document.
381387
func (d *Document) ReadFromBytes(b []byte) error {
382-
_, err := d.ReadFrom(bytes.NewReader(b))
388+
if d.ReadSettings.ValidateInput {
389+
if err := validateXML(bytes.NewReader(b), d.ReadSettings); err != nil {
390+
return err
391+
}
392+
}
393+
_, err := d.Element.readFrom(bytes.NewReader(b), d.ReadSettings)
383394
return err
384395
}
385396

386397
// ReadFromString reads XML from the string 's' into this document.
387398
func (d *Document) ReadFromString(s string) error {
388-
_, err := d.ReadFrom(strings.NewReader(s))
399+
if d.ReadSettings.ValidateInput {
400+
if err := validateXML(strings.NewReader(s), d.ReadSettings); err != nil {
401+
return err
402+
}
403+
}
404+
_, err := d.Element.readFrom(strings.NewReader(s), d.ReadSettings)
389405
return err
390406
}
391407

etree_test.go

+43-12
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@ package etree
77
import (
88
"bytes"
99
"encoding/xml"
10+
"errors"
1011
"io"
12+
"io/fs"
1113
"math/rand"
14+
"os"
15+
"path"
1216
"strings"
1317
"testing"
1418
)
@@ -1540,19 +1544,46 @@ func TestValidateInput(t *testing.T) {
15401544
{`<root><child>x</child></root1>`, `XML syntax error on line 1: element <root> closed by </root1>`},
15411545
}
15421546

1543-
for i, test := range tests {
1544-
doc := NewDocument()
1545-
doc.ReadSettings.ValidateInput = true
1546-
err := doc.ReadFromString(test.s)
1547-
if err == nil {
1548-
if test.err != "" {
1549-
t.Errorf("etree: test #%d:\nExpected error:\n %s\nReceived error:\n nil", i, test.err)
1550-
}
1551-
} else {
1552-
te := err.Error()
1553-
if te != test.err {
1554-
t.Errorf("etree: test #%d:\nExpected error;\n %s\nReceived error:\n %s", i, test.err, te)
1547+
type readFunc func(doc *Document, s string) error
1548+
runTests := func(t *testing.T, read readFunc) {
1549+
for i, test := range tests {
1550+
doc := NewDocument()
1551+
doc.ReadSettings.ValidateInput = true
1552+
err := read(doc, test.s)
1553+
if err == nil {
1554+
if test.err != "" {
1555+
t.Errorf("etree: test #%d:\nExpected error:\n %s\nReceived error:\n nil", i, test.err)
1556+
}
1557+
root := doc.Root()
1558+
if root == nil || root.Tag != "root" {
1559+
t.Errorf("etree: test #%d: failed to read document after input validation", i)
1560+
}
1561+
} else {
1562+
te := err.Error()
1563+
if te != test.err {
1564+
t.Errorf("etree: test #%d:\nExpected error;\n %s\nReceived error:\n %s", i, test.err, te)
1565+
}
15551566
}
15561567
}
15571568
}
1569+
1570+
readFromString := func(doc *Document, s string) error {
1571+
return doc.ReadFromString(s)
1572+
}
1573+
t.Run("ReadFromString", func(t *testing.T) { runTests(t, readFromString) })
1574+
1575+
readFromBytes := func(doc *Document, s string) error {
1576+
return doc.ReadFromBytes([]byte(s))
1577+
}
1578+
t.Run("ReadFromBytes", func(t *testing.T) { runTests(t, readFromBytes) })
1579+
1580+
readFromFile := func(doc *Document, s string) error {
1581+
pathtmp := path.Join(t.TempDir(), "etree-test")
1582+
err := os.WriteFile(pathtmp, []byte(s), fs.ModePerm)
1583+
if err != nil {
1584+
return errors.New("unable to write tmp file for input validation")
1585+
}
1586+
return doc.ReadFromFile(pathtmp)
1587+
}
1588+
t.Run("ReadFromFile", func(t *testing.T) { runTests(t, readFromFile) })
15581589
}

0 commit comments

Comments
 (0)