-
Notifications
You must be signed in to change notification settings - Fork 3.2k
TransT tracker integration #4886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
81d3ed2
a466707
753e993
d7a3b32
93230b1
eb5235b
b7de6ae
3947b1f
af1e15b
75ba69b
f70aa2e
2f5f253
8870fed
bace6ea
1d7c131
2c7656d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Nuclio function definition for the TransT tracker (GPU variant).
metadata:
  name: pth-dschoerk-transt
  namespace: cvat
  annotations:
    name: TransT
    type: tracker
    spec:
    framework: pytorch

spec:
  description: Transformer Tracking (TransT) single-object tracker
  runtime: 'python:3.8'
  handler: main:handler
  eventTimeout: 30s
  env:
    # Makes the cloned tracker sources importable from the handler
    - name: PYTHONPATH
      value: /opt/nuclio/trans-t

  build:
    image: cvat/pth.dschoerk.transt
    # CPU alternative:
    #baseImage: ubuntu:20.04

    # GPU only
    baseImage: nvidia/cuda:11.7.0-devel-ubuntu20.04

    directives:
      preCopy:
        - kind: ENV
          value: PATH="/root/miniconda3/bin:${PATH}"
        - kind: ARG
          value: PATH="/root/miniconda3/bin:${PATH}"
        - kind: RUN
          value: rm -f /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list
        - kind: RUN
          value: apt update && apt install -y --no-install-recommends wget git ca-certificates libglib2.0-0 libgl1 && rm -rf /var/lib/apt/lists/* # libxrender1 libxext6
        - kind: RUN
          value: wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh &&
            chmod +x Miniconda3-latest-Linux-x86_64.sh && ./Miniconda3-latest-Linux-x86_64.sh -b &&
            rm -f Miniconda3-latest-Linux-x86_64.sh
        - kind: WORKDIR
          value: /opt/nuclio
        - kind: RUN
          value: conda create -y -n transt python=3.8
        # Every following RUN executes inside the "transt" conda environment
        - kind: SHELL
          value: '["conda", "run", "-n", "transt", "/bin/bash", "-c"]'
        # Pinned to the same tag as the CPU variant so both images build the same sources
        - kind: RUN
          value: git clone --depth 1 --branch v1.0 https://github.com/dschoerk/TransT trans-t

        # CPU alternative:
        #- kind: RUN
        #  value: pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

        # GPU only
        - kind: RUN
          value: pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html

        - kind: RUN
          value: pip install jsonpickle opencv-python

        # Model weights; TLS verification stays enabled (ca-certificates is installed above)
        - kind: RUN
          value: wget 'https://drive.google.com/uc?id=1Pq0sK-9jmbLAVtgB9-dPDc2pipCxYdM5' -O /transt.pth

        - kind: RUN
          value: apt remove -y git wget
        # No "RUN cd trans-t" here: a directory change does not survive past its
        # own RUN step; the sources are reached through PYTHONPATH instead.
        - kind: ENTRYPOINT
          value: '["conda", "run", "-n", "transt"]'

  triggers:
    myHttpTrigger:
      maxWorkers: 1
      kind: 'http'
      workerAvailabilityTimeoutMilliseconds: 10000
      attributes:
        maxRequestBodySize: 33554432 # 32MB

  # GPU only
  resources:
    limits:
      nvidia.com/gpu: 1

  platform:
    attributes:
      restartPolicy:
        name: always
        maximumRetryCount: 3
      mountMode: volume
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Nuclio function definition for the TransT tracker (CPU variant).
metadata:
  name: pth-dschoerk-transt
  namespace: cvat
  annotations:
    name: TransT
    type: tracker
    spec:
    framework: pytorch

spec:
  description: Transformer Tracking (TransT) single-object tracker
  runtime: 'python:3.8'
  handler: main:handler
  eventTimeout: 30s
  env:
    # Makes the cloned tracker sources importable from the handler
    - name: PYTHONPATH
      value: /opt/nuclio/trans-t

  build:
    image: cvat/pth.dschoerk.transt
    baseImage: ubuntu:20.04

    # GPU only
    # baseImage: nvidia/cuda:11.1-devel-ubuntu20.04

    directives:
      preCopy:
        - kind: ENV
          value: PATH="/root/miniconda3/bin:${PATH}"
        - kind: ARG
          value: PATH="/root/miniconda3/bin:${PATH}"
        - kind: RUN
          value: rm -f /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list
        - kind: RUN
          value: apt update && apt install -y --no-install-recommends wget git ca-certificates libglib2.0-0 libgl1 && rm -rf /var/lib/apt/lists/* # libxrender1 libxext6
        - kind: RUN
          value: wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh &&
            chmod +x Miniconda3-latest-Linux-x86_64.sh && ./Miniconda3-latest-Linux-x86_64.sh -b &&
            rm -f Miniconda3-latest-Linux-x86_64.sh
        - kind: WORKDIR
          value: /opt/nuclio
        - kind: RUN
          value: conda create -y -n transt python=3.8
        # Every following RUN executes inside the "transt" conda environment
        - kind: SHELL
          value: '["conda", "run", "-n", "transt", "/bin/bash", "-c"]'
        - kind: RUN
          value: git clone --depth 1 --branch v1.0 https://github.com/dschoerk/TransT trans-t

        - kind: RUN
          value: pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

        # GPU only
        #- kind: RUN
        #  value: pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html

        - kind: RUN
          value: pip install jsonpickle opencv-python

        # Model weights; TLS verification stays enabled (ca-certificates is installed above)
        - kind: RUN
          value: wget 'https://drive.google.com/uc?id=1Pq0sK-9jmbLAVtgB9-dPDc2pipCxYdM5' -O /transt.pth

        - kind: RUN
          value: apt remove -y git wget
        # No "RUN cd trans-t" here: a directory change does not survive past its
        # own RUN step; the sources are reached through PYTHONPATH instead.
        - kind: ENTRYPOINT
          value: '["conda", "run", "-n", "transt"]'

  triggers:
    myHttpTrigger:
      maxWorkers: 1
      kind: 'http'
      workerAvailabilityTimeoutMilliseconds: 10000
      attributes:
        maxRequestBodySize: 33554432 # 32MB

  # GPU only
  # resources:
  #   limits:
  #     nvidia.com/gpu: 1

  platform:
    attributes:
      restartPolicy:
        name: always
        maximumRetryCount: 3
      mountMode: volume
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import json | ||
import base64 | ||
from PIL import Image | ||
import io | ||
import numpy as np | ||
import traceback | ||
import jsonpickle | ||
import torch | ||
|
||
from pysot_toolkit.bbox import get_axis_aligned_bbox | ||
from pysot_toolkit.trackers.tracker import Tracker | ||
from pysot_toolkit.trackers.net_wrappers import NetWithBackbone | ||
|
||
def create_tracker(net_path='/transt.pth'):
    """Build a TransT tracker around a pretrained checkpoint.

    Args:
        net_path: Path of the model weights. Defaults to ``/transt.pth``, so
            calling the function without arguments behaves as before.

    Returns:
        A ``Tracker`` named ``'transt'`` wrapping a ``NetWithBackbone``. The
        network is asked to use the GPU only when
        ``torch.cuda.is_available()`` reports one.
    """
    net = NetWithBackbone(net_path=net_path, use_gpu=torch.cuda.is_available())
    return Tracker(
        name='transt',
        net=net,
        window_penalty=0.49,
        exemplar_size=128,
        instance_size=256,
    )
|
||
def init_tracker(tracker, img, bbox):
    """Initialise ``tracker`` on ``img`` with the target box ``bbox``.

    ``bbox`` is passed through ``get_axis_aligned_bbox``; the result is
    unpacked as (cx, cy, w, h) and rewritten as
    ``[cx - w / 2, cy - h / 2, w, h]`` before being handed to
    ``tracker.initialize`` under the ``'init_bbox'`` key.

    Returns:
        The same ``tracker`` object that was passed in.
    """
    center_x, center_y, width, height = get_axis_aligned_bbox(np.array(bbox))
    corner_box = [center_x - width / 2, center_y - height / 2, width, height]
    tracker.initialize(img, {'init_bbox': corner_box})
    return tracker
|
||
def track(tracker, img):
    """Advance ``tracker`` by one frame and return the predicted box.

    The ``'target_bbox'`` entry of ``tracker.track(img)`` is read as
    (x, y, width, height) and converted to corner coordinates.

    Returns:
        ``(tracker, (top, left, bottom, right))`` - note the order of the
        inner tuple: the vertical coordinate comes first.
    """
    box = tracker.track(img)['target_bbox']
    x_min, y_min = box[0], box[1]
    x_max = x_min + box[2]
    y_max = y_min + box[3]
    return tracker, (y_min, x_min, y_max, x_max)
|
||
|
||
|
||
def init_context(context):
    """Create the tracker once and keep it on ``context.user_data.model``.

    Progress is reported through ``context.logger`` before and after the
    tracker is built.
    """
    logger = context.logger
    logger.info("Init context... 0%")
    context.user_data.model = create_tracker()
    logger.info("Init context...100%")
|
||
def log(msg):
    """Debug hook that deliberately does nothing; ``msg`` is discarded."""
    return None
Comment on lines
+48
to
+51
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this function can be removed. |
||
|
||
def encode_state(model):
    """Serialise the per-target tracker state into a dict of strings.

    Two attributes are read from ``model.net.net`` (``zf`` and
    ``pos_template``) and the rest from ``model`` itself. Every value is
    passed through ``jsonpickle.encode`` and stored under a dotted key that
    mirrors the attribute path, e.g. ``'model.center_pos'``.

    ``features_initialized`` is optional on the model and is encoded as
    ``False`` when absent.

    Returns:
        The dict of encoded values.
    """
    inner_net = model.net.net
    state = {
        'model.net.net.zf': jsonpickle.encode(inner_net.zf),
        'model.net.net.pos_template': jsonpickle.encode(inner_net.pos_template),
    }

    tracker_attrs = ('window', 'center_pos', 'size', 'channel_average', 'mean', 'std', 'inplace')
    for attr in tracker_attrs:
        state['model.' + attr] = jsonpickle.encode(getattr(model, attr))

    state['model.features_initialized'] = jsonpickle.encode(
        getattr(model, 'features_initialized', False))
    return state
|
||
def decode_state(model, state):
    """Restore tracker attributes on ``model`` from an encoded ``state`` dict.

    ``state`` maps dotted attribute paths (e.g. ``'model.center_pos'``) to
    jsonpickle strings. ``zf`` and ``pos_template`` are written to
    ``model.net.net``; the remaining values are written to ``model`` itself.
    ``features_initialized`` falls back to ``False`` when its key is missing.

    Returns:
        The same ``model`` object, mutated in place.
    """
    # NOTE(review): jsonpickle.decode can instantiate arbitrary objects, so it
    # is only safe on trusted input. ``state`` is supplied by the caller -
    # confirm it can never come from an untrusted source.
    inner_net = model.net.net
    for attr in ('zf', 'pos_template'):
        setattr(inner_net, attr, jsonpickle.decode(state['model.net.net.' + attr]))

    tracker_attrs = ('window', 'center_pos', 'size', 'channel_average', 'mean', 'std', 'inplace')
    for attr in tracker_attrs:
        setattr(model, attr, jsonpickle.decode(state['model.' + attr]))

    model.features_initialized = False
    flag_key = 'model.features_initialized'
    if flag_key in state:
        model.features_initialized = jsonpickle.decode(state[flag_key])

    return model
|
||
def handler(context, event):
    """Function entry point: initialise or advance the tracker for each shape.

    The request body (``event.body``) is read as a mapping with:
      * ``image``  - base64-encoded image,
      * ``shapes`` - list of boxes given as (x1, y1, x2, y2),
      * ``states`` - list of states previously returned by this handler; a
        missing or ``None`` entry means the shape at that index is new.

    For a new shape the tracker is initialised on the image and the shape is
    echoed back unchanged. For a known shape the stored state is restored,
    the tracker is advanced by one frame and the predicted box is returned
    as (left, top, right, bottom). In both cases the tracker state after the
    step is encoded and returned at the same index.

    Returns:
        A ``context.Response`` with JSON body
        ``{"shapes": [...], "states": [...]}`` and status 200, or a response
        without body and status 500 when anything fails.
    """
    try:
        context.logger.info("Run TransT model")
        data = event.body
        buf = io.BytesIO(base64.b64decode(data["image"]))
        # A request without shapes/states is treated as carrying none
        shapes = data.get("shapes") or []
        states = data.get("states") or []

        # Decode as RGB, then reverse the channel axis (RGB -> BGR); copy()
        # turns the reversed view into an array of its own.
        image = np.array(Image.open(buf).convert('RGB'))[:, :, ::-1].copy()

        results = {
            'shapes': [],
            'states': [],
        }
        for i, shape in enumerate(shapes):
            if i >= len(states) or states[i] is None:
                # x1, y1, x2, y2 -> x, y, w, h
                init_shape = (shape[0], shape[1], shape[2] - shape[0], shape[3] - shape[1])
                context.user_data.model = init_tracker(context.user_data.model, image, init_shape)
            else:
                context.user_data.model = decode_state(context.user_data.model, states[i])
                context.user_data.model, (top, left, bottom, right) = track(
                    context.user_data.model, image)
                shape = (left, top, right, bottom)

            results['shapes'].append(shape)
            results['states'].append(encode_state(context.user_data.model))

        return context.Response(body=json.dumps(results), headers={},
                                content_type='application/json', status_code=200)

    except Exception as e:
        # Top-level boundary: report the failure instead of letting it escape.
        details = traceback.format_exc()
        context.logger.error("TransT handler failed: " + details)

        # Keep the on-disk dump, but close the file deterministically.
        with open("/error.log", "w") as logf:
            logf.write(str(e))
            logf.write(details)

        # 500 is the standard "internal server error" status; the previous
        # value (666) is not a valid HTTP status code and lies outside the
        # 4xx/5xx range that clients check for errors.
        return context.Response(headers={},
                                content_type='application/json', status_code=500)
Comment on lines
+138
to
+144
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this should be removed. |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a license to the beginning of the file.