Skip to content

Commit 3926445

Browse files
committed
feat: initial
1 parent bf7de1d commit 3926445

File tree

6 files changed

+288
-0
lines changed

6 files changed

+288
-0
lines changed

.travis.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
language: go
2+
sudo: false
3+
4+
go:
5+
- 1.6.x
6+
- 1.7.x
7+
- 1.8.x
8+
- tip
9+
10+
install:
11+
- go get -u github.com/kardianos/govendor
12+
- go get github.com/campoy/embedmd
13+
- govendor sync
14+
15+
script:
16+
- embedmd -d *.md
17+
- go test -v -covermode=atomic -coverprofile=coverage.out
18+
19+
after_success:
20+
- bash <(curl -s https://codecov.io/bash)
21+
22+
notifications:
23+
webhooks:
24+
urls:
25+
- https://webhooks.gitter.im/e/acc2c57482e94b44f557
26+
on_success: change
27+
on_failure: always
28+
on_start: false

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,30 @@
11
# size
2+
23
Limit size of POST requests for Gin framework
4+
5+
## Example
6+
7+
[embedmd]:# (example/main.go go)
8+
```go
9+
package main
10+
11+
import (
12+
"github.com/gin-contrib/size"
13+
"github.com/gin-gonic/gin"
14+
)
15+
16+
func handler(ctx *gin.Context) {
17+
val := ctx.PostForm("b")
18+
if len(ctx.Errors) > 0 {
19+
return
20+
}
21+
ctx.String(http.StatusOK, "got %s\n", val)
22+
}
23+
24+
func main() {
25+
rtr := gin.Default()
26+
rtr.Use(ratelimit.RateLimiter(10))
27+
rtr.POST("/", handler)
28+
rtr.Run(":8080")
29+
}
30+
```

example/main.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package main
2+
3+
import (
4+
"github.com/gin-contrib/size"
5+
"github.com/gin-gonic/gin"
6+
)
7+
8+
func handler(ctx *gin.Context) {
9+
val := ctx.PostForm("b")
10+
if len(ctx.Errors) > 0 {
11+
return
12+
}
13+
ctx.String(http.StatusOK, "got %s\n", val)
14+
}
15+
16+
func main() {
17+
rtr := gin.Default()
18+
rtr.Use(ratelimit.RateLimiter(10))
19+
rtr.POST("/", handler)
20+
rtr.Run(":8080")
21+
}

size.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package limits
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
8+
"github.com/gin-gonic/gin"
9+
)
10+
11+
type maxBytesReader struct {
12+
ctx *gin.Context
13+
rdr io.ReadCloser
14+
remaining int64
15+
wasAborted bool
16+
sawEOF bool
17+
}
18+
19+
func (mbr *maxBytesReader) tooLarge() (n int, err error) {
20+
n, err = 0, fmt.Errorf("HTTP request too large")
21+
22+
if !mbr.wasAborted {
23+
mbr.wasAborted = true
24+
ctx := mbr.ctx
25+
ctx.Error(err)
26+
ctx.Header("connection", "close")
27+
ctx.String(http.StatusRequestEntityTooLarge, "request too large")
28+
ctx.AbortWithStatus(http.StatusRequestEntityTooLarge)
29+
}
30+
return
31+
}
32+
33+
func (mbr *maxBytesReader) Read(p []byte) (n int, err error) {
34+
toRead := mbr.remaining
35+
if mbr.remaining == 0 {
36+
if mbr.sawEOF {
37+
return mbr.tooLarge()
38+
}
39+
// The underlying io.Reader may not return (0, io.EOF)
40+
// at EOF if the requested size is 0, so read 1 byte
41+
// instead. The io.Reader docs are a bit ambiguous
42+
// about the return value of Read when 0 bytes are
43+
// requested, and {bytes,strings}.Reader gets it wrong
44+
// too (it returns (0, nil) even at EOF).
45+
toRead = 1
46+
}
47+
if int64(len(p)) > toRead {
48+
p = p[:toRead]
49+
}
50+
n, err = mbr.rdr.Read(p)
51+
if err == io.EOF {
52+
mbr.sawEOF = true
53+
}
54+
if mbr.remaining == 0 {
55+
// If we had zero bytes to read remaining (but hadn't seen EOF)
56+
// and we get a byte here, that means we went over our limit.
57+
if n > 0 {
58+
return mbr.tooLarge()
59+
}
60+
return 0, err
61+
}
62+
mbr.remaining -= int64(n)
63+
if mbr.remaining < 0 {
64+
mbr.remaining = 0
65+
}
66+
return
67+
}
68+
69+
func (mbr *maxBytesReader) Close() error {
70+
return mbr.rdr.Close()
71+
}
72+
73+
// RateLimiter returns a middleware that limits the size of request
74+
// When a request is over the limit, the following will happen:
75+
// * Error will be added to the context
76+
// * Connection: close header will be set
77+
// * Error 413 will be send to client (http.StatusRequestEntityTooLarge)
78+
// * Current context will be aborted
79+
func RateLimiter(limit int64) gin.HandlerFunc {
80+
return func(ctx *gin.Context) {
81+
ctx.Request.Body = &maxBytesReader{
82+
ctx: ctx,
83+
rdr: ctx.Request.Body,
84+
remaining: limit,
85+
wasAborted: false,
86+
sawEOF: false,
87+
}
88+
ctx.Next()
89+
}
90+
}

size_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package limits
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"net/http"
7+
"os"
8+
"os/exec"
9+
"testing"
10+
"text/template"
11+
"time"
12+
)
13+
14+
var (
15+
params = struct {
16+
Size int
17+
Port int
18+
}{10, 9388}
19+
20+
codeFile = "/tmp/ratelimit_test_server.go"
21+
serverURL string
22+
)
23+
24+
func init() {
25+
tmpl := template.Must(template.ParseFiles("test_server.tmpl"))
26+
fp, err := os.Create(codeFile)
27+
if err != nil {
28+
panic(fmt.Errorf("can't open %s - %s", codeFile, err))
29+
}
30+
err = tmpl.Execute(fp, params)
31+
if err != nil {
32+
panic(fmt.Errorf("can't create %s - %s", codeFile, err))
33+
}
34+
serverURL = fmt.Sprintf("http://localhost:%d", params.Port)
35+
}
36+
37+
func waitForServer() error {
38+
timeout := 30 * time.Second
39+
ch := make(chan bool)
40+
go func() {
41+
for {
42+
_, err := http.Post(serverURL, "text/plain", nil)
43+
if err == nil {
44+
ch <- true
45+
}
46+
time.Sleep(10 * time.Millisecond)
47+
}
48+
}()
49+
50+
select {
51+
case <-ch:
52+
return nil
53+
case <-time.After(timeout):
54+
return fmt.Errorf("server did not reply after %v", timeout)
55+
}
56+
57+
}
58+
59+
func runServer() (*exec.Cmd, error) {
60+
cmd := exec.Command("go", "run", codeFile)
61+
cmd.Start()
62+
if err := waitForServer(); err != nil {
63+
return nil, err
64+
}
65+
return cmd, nil
66+
}
67+
68+
func doPost(val string) (*http.Response, error) {
69+
cmd, err := runServer()
70+
if err != nil {
71+
return nil, err
72+
}
73+
defer cmd.Process.Kill()
74+
75+
var buf bytes.Buffer
76+
fmt.Fprintf(&buf, "big=%s", val)
77+
return http.Post(serverURL, "application/x-www-form-urlencoded", &buf)
78+
}
79+
80+
func TestRateLimiterOK(t *testing.T) {
81+
resp, err := doPost("abc")
82+
if err != nil {
83+
t.Fatalf("error posting - %s", err)
84+
}
85+
if resp.StatusCode != http.StatusOK {
86+
t.Fatalf("bad status - %d", resp.StatusCode)
87+
}
88+
}
89+
90+
func TestRateLimiterOver(t *testing.T) {
91+
resp, err := doPost("abcdefghijklmnop")
92+
if err != nil {
93+
t.Fatalf("error posting - %s", err)
94+
}
95+
if resp.StatusCode != http.StatusRequestEntityTooLarge {
96+
t.Fatalf("bad status - %d", resp.StatusCode)
97+
}
98+
}

test_server.tmpl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package main
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/gin-contrib/limits"
7+
"github.com/gin-gonic/gin"
8+
)
9+
10+
func handler(ctx *gin.Context) {
11+
val := ctx.PostForm("b")
12+
if len(ctx.Errors) > 0 {
13+
return
14+
}
15+
ctx.String(http.StatusOK, val)
16+
}
17+
18+
func main() {
19+
rtr := gin.Default()
20+
rtr.Use(limits.RateLimiter({{.Size}}))
21+
rtr.POST("/", handler)
22+
rtr.Run(":{{.Port}}")
23+
}

0 commit comments

Comments
 (0)