Commit 7dd6944

Refactoring for ease of secondary development 2023-04-23
1 parent c7aa7ff commit 7dd6944

File tree

3 files changed: +280 / -265 lines changed


README.md

Lines changed: 6 additions & 1 deletion
@@ -10,8 +10,13 @@ wget,golang
 
 # How Run
 ```
-go build -o ~/go/bin/wget main.go
+go build -o ~/go/bin/wget main.go
+
+~/go/bin/wget -u 'https://huggingface.co/Salesforce/codegen-16B-mono/resolve/main/pytorch_model.bin'
+
 ~/go/bin/wget -u "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth"
+shasum -a 256 t2iadapter_style_sd14v1.pth
+
 ```
 
 <img width="800" src=https://user-images.githubusercontent.com/18223385/233818824-305fea4a-a5ed-4a70-8ade-33d6a8c3c734.gif>

main.go

Lines changed: 2 additions & 264 deletions
@@ -1,47 +1,11 @@
 package main
 
 import (
-    "errors"
     "flag"
-    "fmt"
-    "io"
-    "log"
-    "net/http"
-    "net/url"
-    "os"
-    "path/filepath"
-    "runtime"
-    "strconv"
-    "strings"
-    "sync"
-    "time"
-
-    "github.com/cheggaaa/pb"
-    "github.com/hktalent/PipelineHttp"
-)
-
-type Worker struct {
-    Url       string
-    File      *os.File
-    Count     int64
-    SyncWG    sync.WaitGroup
-    TotalSize int64
-    Progress
-}
-
-type Progress struct {
-    Pool *pb.Pool
-    Bars []*pb.ProgressBar
-}
-
-var (
-    pipelineHttp = PipelineHttp.NewPipelineHttp()
-    sCurDir, err = os.Getwd()
+    "github.com/hktalent/wget-go/pkg"
 )
 
 func main() {
-    runtime.GOMAXPROCS(runtime.NumCPU())
-    pipelineHttp.SetNoLimit()
     //os.Args = []string{"", "-u", "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth"}
     var t = flag.Bool("t", false, "file name with datetime")
 
@@ -51,231 +15,5 @@ func main() {
     flag.StringVar(&downloadUrl, "u", "", "Download URL")
     flag.Parse()
 
-    // Get header from the url
-    log.Println("Url:", downloadUrl)
-    szOldUrl := downloadUrl
-    szFileName, s2, fileSize, err := getSizeAndCheckRangeSupport(downloadUrl)
-    if nil != err {
-        *workerCount = 1
-    }
-    if "" != s2 {
-        downloadUrl = s2
-    }
-    log.Printf("File size: %d bytes, workerCount %d\n", fileSize, *workerCount)
-
-    var filePath string
-    if *t {
-        filePath = sCurDir + string(filepath.Separator) + strconv.FormatInt(time.Now().UnixNano(), 10) + "_" + getFileName(downloadUrl)
-    } else {
-        if "" != out {
-            filePath = sCurDir + string(filepath.Separator) + out
-        } else if "" != szFileName {
-            filePath = sCurDir + string(filepath.Separator) + szFileName
-        } else {
-            filePath = sCurDir + string(filepath.Separator) + getFileName(szOldUrl)
-        }
-    }
-    log.Printf("Local path: %s\n", filePath)
-
-    // TODO: optimize later so that a second run after a failure resumes from the break point (os.O_APPEND|)
-    f, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, 0666)
-    handleError(err)
-    defer f.Close()
-
-    // New worker struct to download file
-    var worker = Worker{
-        Url:       downloadUrl,
-        File:      f,
-        Count:     *workerCount,
-        TotalSize: fileSize,
-    }
-
-    var start, end, partialSize int64
-    if 0 < fileSize%(*workerCount) {
-        partialSize = fileSize / (*workerCount - 1)
-    } else {
-        partialSize = fileSize / *workerCount
-    }
-    now := time.Now().UTC()
-    for num := int64(0); num < worker.Count; num++ {
-        // New sub progress bar (give it 0 at first for new instance and assign real size later on.)
-        bar := pb.New(0).Prefix(fmt.Sprintf("Part %d 0%% ", num+1))
-        bar.ShowSpeed = true
-        bar.SetMaxWidth(100)
-        bar.SetUnits(pb.U_BYTES_DEC)
-        bar.SetRefreshRate(time.Second)
-        bar.ShowPercent = true
-        worker.Progress.Bars = append(worker.Progress.Bars, bar)
-
-        if num == worker.Count {
-            end = fileSize // last part
-        } else {
-            end = start + partialSize
-            if end > fileSize {
-                end = fileSize
-            }
-        }
-
-        worker.SyncWG.Add(1)
-        go worker.writeRange(num, start, end-1)
-        start = end
-    }
-    worker.Progress.Pool, err = pb.StartPool(worker.Progress.Bars...)
-    handleError(err)
-    worker.SyncWG.Wait()
-    worker.Progress.Pool.Stop()
-    log.Println("Elapsed time:", time.Since(now))
-    log.Println("Done!")
-}
-
-func (w *Worker) writeRange(partNum int64, start int64, end int64) {
-    var written int64
-
-    defer w.Bars[partNum].Finish()
-    defer w.SyncWG.Done()
-    if start >= end {
-        return
-    }
-    body, size, err := w.getRangeBody(start, end)
-    if err != nil {
-        log.Fatalf("Part %d request error: %s\n", partNum+1, err.Error())
-    }
-    defer body.Close()
-
-    // Assign total size to progress bar
-    w.Bars[partNum].Total = size
-
-    // New percentage flag
-    percentFlag := map[int64]bool{}
-
-    // make a buffer to keep chunks that are read
-    buf := make([]byte, 8*1024)
-    for {
-        nr, er := body.Read(buf)
-        if nr > 0 {
-            nw, err := w.File.WriteAt(buf[0:nr], start)
-            if err != nil {
-                log.Fatalf("Part %d occured error: %s.\n", partNum+1, err.Error())
-            }
-            if nr != nw {
-                log.Fatalf("Part %d occured error of short writiing.\n", partNum+1)
-            }
-
-            start = int64(nw) + start
-            if nw > 0 {
-                written += int64(nw)
-            }
-
-            // Update written bytes on progress bar
-            w.Bars[int(partNum)].Set64(written)
-
-            // Update current percentage on progress bars
-            p := int64(float32(written) / float32(size) * 100)
-            _, flagged := percentFlag[p]
-            if !flagged {
-                percentFlag[p] = true
-                //w.Bars[int(partNum)].Prefix(fmt.Sprintf("Part %d(%d - %d) %d%% ", partNum+1, start, end+1, p))
-                w.Bars[int(partNum)].Prefix(fmt.Sprintf("Part %d %d%% ", partNum+1, p))
-            }
-        }
-        if er != nil {
-            if er.Error() == "EOF" {
-                if size == written {
-                    // Download successfully
-                } else {
-                    handleError(errors.New(fmt.Sprintf("Part %d unfinished.\n", partNum+1)))
-                }
-                break
-            }
-            handleError(errors.New(fmt.Sprintf("Part %d occured error: %s\n", partNum+1, er.Error())))
-        }
-    }
-}
-
-func (w *Worker) getRangeBody(start int64, end int64) (io.ReadCloser, int64, error) {
-    //var client http.Client
-    req, err := http.NewRequest("GET", w.Url, nil)
-    // req.Header.Set("cookie", "")
-    //log.Printf("Request header: %s\n", req.Header)
-    if err != nil {
-        return nil, 0, err
-    }
-
-    // Set range header
-    req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end))
-    resp, err := pipelineHttp.GetRawClient4Http2().Do(req) // myClient
-    if err != nil {
-        return nil, 0, err
-    }
-    size, err := strconv.ParseInt(resp.Header["Content-Length"][0], 10, 64)
-    return resp.Body, size, err
-}
-
-func getHds(header http.Header, a ...string) string {
-    for _, x := range a {
-        if s, ok := header[x]; ok {
-            if 0 < len(s) {
-                return strings.Join(s, ",")
-            }
-        }
-    }
-    return ""
-}
-
-/*
-1. Check if the URL supports Accept-Ranges
-2. Confirm the size of the downloaded resource
-*/
-func getSizeAndCheckRangeSupport(szUrl1 string) (szFileName, szUrl string, size int64, err error) {
-    req, err := http.NewRequest("HEAD", szUrl1, nil)
-    if err != nil {
-        return
-    }
-    // req.Header.Set("cookie", "")
-    // log.Printf("Request header: %s\n", req.Header)
-    res, err := pipelineHttp.GetRawClient4Http2().Do(req)
-    if err != nil {
-        return
-    }
-    defer res.Body.Close()
-    header := res.Header
-    if s2 := getHds(header, "Location"); "" != s2 {
-        return getSizeAndCheckRangeSupport(s2)
-    }
-
-    acceptRanges, supported := header["Accept-Ranges"]
-    if !supported {
-        return "", szUrl1, 0, errors.New("Doesn't support header `Accept-Ranges`.")
-    } else if supported && acceptRanges[0] != "bytes" {
-        return "", szUrl1, 0, errors.New("Support `Accept-Ranges`, but value is not `bytes`.")
-    }
-    if s1 := getHds(header, "Content-Length", "X-Linked-Size"); "" != s1 {
-        size, err = strconv.ParseInt(s1, 10, 64)
-    }
-    szUrl = szUrl1
-    // attachment; filename*=UTF-8''t2iadapter_style_sd14v1.pth; filename="t2iadapter_style_sd14v1.pth";
-    if s1 := getHds(header, "Content-Disposition"); "" != s1 {
-        a := strings.Split(s1, "; ")
-        k := "filename*=UTF-8''"
-        for _, x := range a {
-            if strings.HasPrefix(x, k) {
-                szFileName, err = url.QueryUnescape(x[len(k):])
-                break
-            }
-        }
-    }
-    return
-}
-
-func getFileName(downloadUrl string) string {
-    urlStruct, err := url.Parse(downloadUrl)
-    handleError(err)
-    return filepath.Base(urlStruct.Path)
-}
-
-func handleError(err error) {
-    if err != nil {
-        log.Println("err:", err)
-        os.Exit(1)
-    }
+    pkg.Main(t, downloadUrl, out, workerCount)
 }
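
The point of the refactor is that the download logic now lives in github.com/hktalent/wget-go/pkg, so other Go programs can call it directly instead of shelling out to the wget binary. Below is a minimal sketch of such secondary use. The pkg.Main signature is only inferred from the call site and the flag declarations above (t appears to be a *bool, downloadUrl and out strings, workerCount a *int64); pkg itself is not shown in this commit view, so treat the parameter types as assumptions.

```go
package main

import (
	"github.com/hktalent/wget-go/pkg"
)

func main() {
	// Values that would normally come from command-line flags; taken here as
	// literals purely for illustration.
	useTimestampName := false // corresponds to the -t flag (assumed *bool)
	downloadUrl := "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth"
	out := ""               // output file name; empty falls back to the name from headers/URL
	workerCount := int64(8) // number of parallel range workers (assumed *int64)

	// Delegate to the extracted entry point, mirroring the call in the refactored main.go.
	pkg.Main(&useTimestampName, downloadUrl, out, &workerCount)
}
```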
