Commit 82700f6

yasakova-anastasia, dschoerk, nmanovic, and Marishka17 authored
TransT tracker integration (#5226)
* AI tracker was one frame late
* TransT tracker integration
* fixed linter issues
* added transt tracker to readme
* clone a fixed transt version
* nvidia/cuda:11.1-devel-ubuntu20.04 not available anymore, replaced with nvidia/cuda:11.7.0-devel-ubuntu20.04
* Fix show empty tasks (#100)
* Fix show empty tasks
* v1.41.1
* Update changelog

Co-authored-by: Boris Sekachev <[email protected]>

* [Snyk] Upgrade dotenv-webpack from 7.1.1 to 8.0.0 (#98)

feat: upgrade dotenv-webpack from 7.1.1 to 8.0.0

Snyk has created this PR to upgrade dotenv-webpack from 7.1.1 to 8.0.0.

See this package in npm: https://www.npmjs.com/package/dotenv-webpack
See this project in Snyk: https://app.snyk.io/org/cvat/project/6c66365f-c154-46f2-b5db-4a4cd35fea4d?utm_source=github&utm_medium=referral&page=upgrade-pr

Co-authored-by: snyk-bot <[email protected]>

* Add repo disclaimer in README (#127)
* Update README.md
* Update README.md
* Update tools-control.tsx
* Add ModelHandler class
* Small fixes

Co-authored-by: dschoerk <[email protected]>
Co-authored-by: Dominik Schörkhuber <[email protected]>
Co-authored-by: Dominik Schörkhuber <[email protected]>
Co-authored-by: Nikita Manovich <[email protected]>
Co-authored-by: Maria Khrustaleva <[email protected]>
Co-authored-by: Boris Sekachev <[email protected]>
Co-authored-by: Andrey Zhavoronkov <[email protected]>
Co-authored-by: snyk-bot <[email protected]>
Co-authored-by: Maxim Zhiltsov <[email protected]>
1 parent c7125a8 commit 82700f6

5 files changed: +269 -0 lines changed

README.md

@@ -176,6 +176,7 @@ can be ran on:
 | [Text detection v4](/serverless/openvino/omz/intel/text-detection-0004/nuclio) | detector | OpenVINO | ✔️ | |
 | [YOLO v5](/serverless/pytorch/ultralytics/yolov5/nuclio) | detector | PyTorch | ✔️ | |
 | [SiamMask](/serverless/pytorch/foolwood/siammask/nuclio) | tracker | PyTorch | ✔️ | ✔️ |
+| [TransT](/serverless/pytorch/dschoerk/transt/nuclio) | tracker | PyTorch | ✔️ | ✔️ |
 | [f-BRS](/serverless/pytorch/saic-vul/fbrs/nuclio) | interactor | PyTorch | ✔️ | |
 | [HRNet](/serverless/pytorch/saic-vul/hrnet/nuclio) | interactor | PyTorch | | ✔️ |
 | [Inside-Outside Guidance](/serverless/pytorch/shiyinzhang/iog/nuclio) | interactor | PyTorch | ✔️ | |

@@ -0,0 +1,79 @@
+metadata:
+  name: pth-dschoerk-transt
+  namespace: cvat
+  annotations:
+    name: TransT
+    type: tracker
+    spec:
+    framework: pytorch
+
+spec:
+  description: Fast Online Object Tracking and Segmentation
+  runtime: 'python:3.8'
+  handler: main:handler
+  eventTimeout: 30s
+  env:
+    - name: PYTHONPATH
+      value: /opt/nuclio/trans-t
+
+  build:
+    image: cvat/pth.dschoerk.transt
+    baseImage: nvidia/cuda:11.7.0-devel-ubuntu20.04
+
+    directives:
+      preCopy:
+        - kind: ENV
+          value: PATH="/root/miniconda3/bin:${PATH}"
+        - kind: ARG
+          value: PATH="/root/miniconda3/bin:${PATH}"
+        - kind: RUN
+          value: rm -f /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list
+        - kind: RUN
+          value: apt update && apt install -y --no-install-recommends wget git ca-certificates libglib2.0-0 libgl1 && rm -rf /var/lib/apt/lists/* # libxrender1 libxext6
+        - kind: RUN
+          value: wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh &&
+            chmod +x Miniconda3-latest-Linux-x86_64.sh && ./Miniconda3-latest-Linux-x86_64.sh -b &&
+            rm -f Miniconda3-latest-Linux-x86_64.sh
+        - kind: WORKDIR
+          value: /opt/nuclio
+        - kind: RUN
+          value: conda create -y -n transt python=3.8
+        - kind: SHELL
+          value: '["conda", "run", "-n", "transt", "/bin/bash", "-c"]'
+        - kind: RUN
+          value: git clone https://github.com/dschoerk/TransT trans-t
+
+        - kind: RUN
+          value: pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
+
+        - kind: RUN
+          value: pip install jsonpickle opencv-python
+
+        - kind: RUN
+          value: wget --no-check-certificate 'https://drive.google.com/uc?id=1Pq0sK-9jmbLAVtgB9-dPDc2pipCxYdM5' -O /transt.pth
+
+        - kind: RUN
+          value: apt remove -y git wget
+        - kind: RUN
+          value: cd trans-t
+        - kind: ENTRYPOINT
+          value: '["conda", "run", "-n", "transt"]'
+
+  triggers:
+    myHttpTrigger:
+      maxWorkers: 1
+      kind: 'http'
+      workerAvailabilityTimeoutMilliseconds: 10000
+      attributes:
+        maxRequestBodySize: 33554432 # 32MB
+
+  resources:
+    limits:
+      nvidia.com/gpu: 1
+
+  platform:
+    attributes:
+      restartPolicy:
+        name: always
+        maximumRetryCount: 3
+      mountMode: volume

@@ -0,0 +1,75 @@
+metadata:
+  name: pth-dschoerk-transt
+  namespace: cvat
+  annotations:
+    name: TransT
+    type: tracker
+    spec:
+    framework: pytorch
+
+spec:
+  description: Fast Online Object Tracking and Segmentation
+  runtime: 'python:3.8'
+  handler: main:handler
+  eventTimeout: 30s
+  env:
+    - name: PYTHONPATH
+      value: /opt/nuclio/trans-t
+
+  build:
+    image: cvat/pth.dschoerk.transt
+    baseImage: ubuntu:20.04
+
+    directives:
+      preCopy:
+        - kind: ENV
+          value: PATH="/root/miniconda3/bin:${PATH}"
+        - kind: ARG
+          value: PATH="/root/miniconda3/bin:${PATH}"
+        - kind: RUN
+          value: rm -f /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list
+        - kind: RUN
+          value: apt update && apt install -y --no-install-recommends wget git ca-certificates libglib2.0-0 libgl1 && rm -rf /var/lib/apt/lists/* # libxrender1 libxext6
+        - kind: RUN
+          value: wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh &&
+            chmod +x Miniconda3-latest-Linux-x86_64.sh && ./Miniconda3-latest-Linux-x86_64.sh -b &&
+            rm -f Miniconda3-latest-Linux-x86_64.sh
+        - kind: WORKDIR
+          value: /opt/nuclio
+        - kind: RUN
+          value: conda create -y -n transt python=3.8
+        - kind: SHELL
+          value: '["conda", "run", "-n", "transt", "/bin/bash", "-c"]'
+        - kind: RUN
+          value: git clone --depth 1 --branch v1.0 https://github.com/dschoerk/TransT trans-t
+
+        - kind: RUN
+          value: pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+        - kind: RUN
+          value: pip install jsonpickle opencv-python
+
+        - kind: RUN
+          value: wget --no-check-certificate 'https://drive.google.com/uc?id=1Pq0sK-9jmbLAVtgB9-dPDc2pipCxYdM5' -O /transt.pth
+
+        - kind: RUN
+          value: apt remove -y git wget
+        - kind: RUN
+          value: cd trans-t
+        - kind: ENTRYPOINT
+          value: '["conda", "run", "-n", "transt"]'
+
+  triggers:
+    myHttpTrigger:
+      maxWorkers: 1
+      kind: 'http'
+      workerAvailabilityTimeoutMilliseconds: 10000
+      attributes:
+        maxRequestBodySize: 33554432 # 32MB
+
+  platform:
+    attributes:
+      restartPolicy:
+        name: always
+        maximumRetryCount: 3
+      mountMode: volume
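
Note on the trigger configuration: `maxRequestBodySize: 33554432` is 32 MiB, and the handler expects the frame as base64 text, which is about one third larger than the raw bytes. The following is a minimal sketch, not part of this commit, of checking whether an encoded frame would fit under that limit. The helper names `base64_length` and `fits_request_limit` and the `overhead` parameter (for the rest of the JSON body) are illustrative assumptions.

```python
import math

# Value of maxRequestBodySize in the trigger configuration (32 MiB).
MAX_REQUEST_BODY_SIZE = 33554432


def base64_length(raw_size: int) -> int:
    """Length in characters of the padded base64 encoding of raw_size bytes."""
    return 4 * math.ceil(raw_size / 3)


def fits_request_limit(raw_size: int, overhead: int = 0,
                       limit: int = MAX_REQUEST_BODY_SIZE) -> bool:
    """Return True if the encoded image plus `overhead` bytes stays within `limit`.

    `overhead` is a hypothetical allowance for the shapes, states, and JSON
    punctuation that travel in the same request body.
    """
    return base64_length(raw_size) + overhead <= limit
```

Under these assumptions a raw image of up to 24 MiB fits when nothing else is in the body; serialized tracker states reduce that budget further.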

@@ -0,0 +1,36 @@
+import base64
+import io
+import json
+
+import numpy as np
+from model_handler import ModelHandler
+from PIL import Image
+
+
+def init_context(context):
+    context.logger.info("Init context...  0%")
+    model = ModelHandler()
+    context.user_data.model = model
+    context.logger.info("Init context...100%")
+
+def handler(context, event):
+    context.logger.info("Run TransT model")
+    data = event.body
+    buf = io.BytesIO(base64.b64decode(data["image"]))
+    shapes = data.get("shapes")
+    states = data.get("states")
+
+    image = Image.open(buf).convert('RGB')
+    image = np.array(image)[:, :, ::-1].copy()
+
+    results = {
+        'shapes': [],
+        'states': []
+    }
+    for i, shape in enumerate(shapes):
+        shape, state = context.user_data.model.infer(image, shape, states[i] if i < len(states) else None)
+        results['shapes'].append(shape)
+        results['states'].append(state)
+
+    return context.Response(body=json.dumps(results), headers={},
+        content_type='application/json', status_code=200)
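
Note on the request format: the handler reads three keys from the JSON body, `image` (base64), `shapes`, and `states`, and pairs each shape with the state at the same index, falling back to `None` when no state exists yet. The following is a minimal sketch, not part of this commit, of how a caller might build such a body. The function names `build_tracker_request` and `pair_shapes_with_states` are hypothetical, and the `[x1, y1, x2, y2]` shape layout is inferred from how `infer` computes width and height.

```python
import base64
import json


def build_tracker_request(image_bytes, shapes, states=None):
    """Build a JSON body with the keys the handler reads.

    image_bytes: encoded image file contents (for example PNG or JPEG bytes).
    shapes: list of boxes, assumed here to be [x1, y1, x2, y2].
    states: list of per-shape states from the previous response, if any.
    """
    return json.dumps({
        "image": base64.b64encode(image_bytes).decode("ascii"),
        "shapes": shapes,
        # Always send a list: the handler calls len() on this value.
        "states": states if states is not None else [],
    })


def pair_shapes_with_states(shapes, states):
    """Mirror the handler loop: a shape with no saved state is paired with None."""
    return [(shape, states[i] if i < len(states) else None)
            for i, shape in enumerate(shapes)]
```

On the first frame the caller would send an empty `states` list so every shape initializes a tracker; on later frames it would send back the `states` list from the previous response.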

@@ -0,0 +1,78 @@
+# Copyright (C) 2022 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import jsonpickle
+import numpy as np
+import torch
+from pysot_toolkit.bbox import get_axis_aligned_bbox
+from pysot_toolkit.trackers.net_wrappers import NetWithBackbone
+from pysot_toolkit.trackers.tracker import Tracker
+
+
+class ModelHandler:
+    def __init__(self):
+        use_gpu = torch.cuda.is_available()
+        net_path = '/transt.pth'  # Absolute path of the model
+        net = NetWithBackbone(net_path=net_path, use_gpu=use_gpu)
+        self.tracker = Tracker(name='transt', net=net, window_penalty=0.49, exemplar_size=128, instance_size=256)
+
+    def decode_state(self, state):
+        self.tracker.net.net.zf = jsonpickle.decode(state['model.net.net.zf'])
+        self.tracker.net.net.pos_template = jsonpickle.decode(state['model.net.net.pos_template'])
+
+        self.tracker.window = jsonpickle.decode(state['model.window'])
+        self.tracker.center_pos = jsonpickle.decode(state['model.center_pos'])
+        self.tracker.size = jsonpickle.decode(state['model.size'])
+        self.tracker.channel_average = jsonpickle.decode(state['model.channel_average'])
+        self.tracker.mean = jsonpickle.decode(state['model.mean'])
+        self.tracker.std = jsonpickle.decode(state['model.std'])
+        self.tracker.inplace = jsonpickle.decode(state['model.inplace'])
+
+        self.tracker.features_initialized = False
+        if 'model.features_initialized' in state:
+            self.tracker.features_initialized = jsonpickle.decode(state['model.features_initialized'])
+
+    def encode_state(self):
+        state = {}
+        state['model.net.net.zf'] = jsonpickle.encode(self.tracker.net.net.zf)
+        state['model.net.net.pos_template'] = jsonpickle.encode(self.tracker.net.net.pos_template)
+        state['model.window'] = jsonpickle.encode(self.tracker.window)
+        state['model.center_pos'] = jsonpickle.encode(self.tracker.center_pos)
+        state['model.size'] = jsonpickle.encode(self.tracker.size)
+        state['model.channel_average'] = jsonpickle.encode(self.tracker.channel_average)
+        state['model.mean'] = jsonpickle.encode(self.tracker.mean)
+        state['model.std'] = jsonpickle.encode(self.tracker.std)
+        state['model.inplace'] = jsonpickle.encode(self.tracker.inplace)
+        state['model.features_initialized'] = jsonpickle.encode(getattr(self.tracker, 'features_initialized', False))
+
+        return state
+
+    def init_tracker(self, img, bbox):
+        cx, cy, w, h = get_axis_aligned_bbox(np.array(bbox))
+        gt_bbox_ = [cx - w / 2, cy - h / 2, w, h]
+        init_info = {'init_bbox': gt_bbox_}
+        self.tracker.initialize(img, init_info)
+
+    def track(self, img):
+        outputs = self.tracker.track(img)
+        prediction_bbox = outputs['target_bbox']
+
+        left = prediction_bbox[0]
+        top = prediction_bbox[1]
+        right = prediction_bbox[0] + prediction_bbox[2]
+        bottom = prediction_bbox[1] + prediction_bbox[3]
+        return (left, top, right, bottom)
+
+    def infer(self, image, shape, state):
+        if state is None:
+            init_shape = (shape[0], shape[1], shape[2] - shape[0], shape[3] - shape[1])
+
+            self.init_tracker(image, init_shape)
+            state = self.encode_state()
+        else:
+            self.decode_state(state)
+            shape = self.track(image)
+            state = self.encode_state()
+
+        return shape, state
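
Note on box formats: `infer` receives a corner-format box `(left, top, right, bottom)` and converts it to `(left, top, width, height)` before initializing the tracker, while `track` converts the tracker's `(left, top, width, height)` output back to corners. The two conversions can be sketched in isolation as below. The function names are illustrative and do not appear in the commit.

```python
def xyxy_to_xywh(box):
    """(left, top, right, bottom) -> (left, top, width, height), as done in infer()."""
    left, top, right, bottom = box
    return (left, top, right - left, bottom - top)


def xywh_to_xyxy(box):
    """(left, top, width, height) -> (left, top, right, bottom), as done in track()."""
    left, top, width, height = box
    return (left, top, left + width, top + height)


# A box converted to width/height form and back is unchanged.
corner_box = (10, 20, 110, 220)
size_box = xyxy_to_xywh(corner_box)
round_trip = xywh_to_xyxy(size_box)
```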
