
Commit 8e79bef

update code
1 parent 46dd2f7 commit 8e79bef


73 files changed: +3675 additions, -1151 deletions

PromptHash_COCO.ps1

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+conda activate torch
+$lr = 0.001
+$gpu_rank = 0
+$valid_freq = 1
+$epochs = 100
+$res_name = "result/Result_PromptHash_COCO"
+$recon = 0.001
+$hyper_cls_inter = 20.0
+$hyper_quan = 1.0
+
+python main.py --is-train --dataset coco --query-num 5000 --train-num 10000 --lr $lr --rank $gpu_rank --valid-freq $valid_freq --epochs $epochs --result-name $res_name --hyper-recon $recon --hyper-cls-inter $hyper_cls_inter --hyper-quan $hyper_quan

PromptHash_COCO.sh

Lines changed: 4 additions & 2 deletions
@@ -2,10 +2,12 @@
 
 ################# bash -x ***.sh #################
 lr=0.001
-gpu_rank=0
+gpu_rank=3
 valid_freq=1
 epochs=100
 res_name="result/Result_PromptHash_COCO"
 recon=0.001
+hyper_cls_inter=20.0
+hyper_quan=1.0
 
-python main.py --is-train --dataset coco --query-num 5000 --train-num 10000 --lr "$lr" --rank "$gpu_rank" --valid-freq "$valid_freq" --epochs "$epochs" --result-name "$res_name" --hyper-recon "$recon"
+python main.py --is-train --dataset coco --query-num 5000 --train-num 10000 --lr "$lr" --rank "$gpu_rank" --valid-freq "$valid_freq" --epochs "$epochs" --result-name "$res_name" --hyper-recon "$recon" --hyper-cls-inter "$hyper_cls_inter" --hyper-quan "$hyper_quan"

PromptHash_Flickr.ps1

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+conda activate torch
+$lr = 0.001
+$gpu_rank = 0
+$valid_freq = 1
+$epochs = 100
+$res_name = "result/Result_PromptHash_Flickr"
+$recon = 0.001
+$hyper_cls_inter = 5.0
+$hyper_quan = 0.1
+
+python main.py --is-train --dataset flickr25k --query-num 2000 --train-num 10000 --lr $lr --rank $gpu_rank --valid-freq $valid_freq --epochs $epochs --result-name $res_name --hyper-recon $recon --hyper-cls-inter $hyper_cls_inter --hyper-quan $hyper_quan

PromptHash_Flickr.sh

Lines changed: 4 additions & 2 deletions
@@ -2,10 +2,12 @@
 
 ################# bash -x ***.sh #################
 lr=0.001
-gpu_rank=0
+gpu_rank=7
 valid_freq=1
 epochs=100
 res_name="result/Result_PromptHash_Flickr"
 recon=0.001
+hyper_cls_inter=5.0
+hyper_quan=0.1
 
-python main.py --is-train --dataset flickr25k --query-num 2000 --train-num 10000 --lr "$lr" --rank "$gpu_rank" --valid-freq "$valid_freq" --epochs "$epochs" --result-name "$res_name" --hyper-recon "$recon"
+python main.py --is-train --dataset flickr25k --query-num 2000 --train-num 10000 --lr "$lr" --rank "$gpu_rank" --valid-freq "$valid_freq" --epochs "$epochs" --result-name "$res_name" --hyper-recon "$recon" --hyper-cls-inter "$hyper_cls_inter" --hyper-quan "$hyper_quan"

PromptHash_NUSWIDE.ps1

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+conda activate torch
+$lr = 0.001
+$gpu_rank = 0
+$valid_freq = 1
+$epochs = 100
+$res_name = "result/Result_PromptHash_NUSWIDE"
+$recon = 0.001
+$hyper_cls_inter = 5.0
+$hyper_quan = 0.1
+
+python main.py --is-train --dataset nuswide --caption-file caption.txt --query-num 2100 --train-num 10500 --lr $lr --rank $gpu_rank --valid-freq $valid_freq --epochs $epochs --result-name $res_name --hyper-recon $recon --hyper-cls-inter $hyper_cls_inter --hyper-quan $hyper_quan

PromptHash_NUSWIDE.sh

Lines changed: 4 additions & 2 deletions
@@ -2,10 +2,12 @@
 
 ################# bash -x ***.sh #################
 lr=0.001
-gpu_rank=0
+gpu_rank=4
 valid_freq=1
 epochs=100
 res_name="result/Result_PromptHash_NUSWIDE"
 recon=0.001
+hyper_cls_inter=5.0
+hyper_quan=0.1
 
-python main.py --is-train --dataset nuswide --caption-file caption.txt --query-num 2100 --train-num 10500 --lr "$lr" --rank "$gpu_rank" --valid-freq "$valid_freq" --epochs "$epochs" --result-name "$res_name" --hyper-recon "$recon"
+python main.py --is-train --dataset nuswide --caption-file caption.txt --query-num 2100 --train-num 10500 --lr "$lr" --rank "$gpu_rank" --valid-freq "$valid_freq" --epochs "$epochs" --result-name "$res_name" --hyper-recon "$recon" --hyper-cls-inter "$hyper_cls_inter" --hyper-quan "$hyper_quan"
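All six launch scripts above pass two new flags, `--hyper-cls-inter` and `--hyper-quan`, through to `main.py`. The diff does not show how `main.py` consumes them; the sketch below only illustrates how such flags are typically declared with `argparse`. The flag names come from the commands above, while the defaults and surrounding structure are assumptions.

```python
# Illustrative only: how the new hyperparameter flags might be declared in main.py.
# Only the flag names come from the scripts above; defaults and structure are assumed.
import argparse

parser = argparse.ArgumentParser(description="PromptHash training (sketch)")
parser.add_argument("--is-train", action="store_true")
parser.add_argument("--dataset", type=str, default="coco")
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--hyper-recon", type=float, default=0.001)
# New in this commit: hyperparameters supplied by the updated scripts.
parser.add_argument("--hyper-cls-inter", type=float, default=20.0)
parser.add_argument("--hyper-quan", type=float, default=1.0)

args = parser.parse_args()
# argparse maps dashes to underscores, so the values arrive as:
print(args.hyper_cls_inter, args.hyper_quan)
```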

README.md

Lines changed: 6 additions & 4 deletions
@@ -14,20 +14,22 @@ cd PromptHash
 
 2. Please install the following packages:
 ```bash
-conda create -n prompthash python=3.11 -y
+conda create -n prompthash python=3.13 -y
 conda activate prompthash
 ```
 
-3. Install PyTorch 2.3.1, mamba_ssm, and causal-conv1d:
+3. Install PyTorch 2.7.0, mamba_ssm, and causal-conv1d:
 ```bash
-# Install PyTorch 2.3.1
-conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 -c pytorch
+# Install PyTorch 2.7.0
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
 
 # Install mamba_ssm
 # Please refer to https://github.com/state-spaces/mamba for detailed installation instructions
+# If you are using a CUDA 12.8 environment, you can download the pre-built whl file from the release page
 
 # Install causal-conv1d
 # Please refer to https://github.com/Dao-AILab/causal-conv1d for detailed installation instructions
+# If you are using a CUDA 12.8 environment, you can download the pre-built whl file from the release page
 ```
 
 ## Data 🗂️
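A quick way to confirm that the environment described in the updated README is usable is to import the packages it names. This is a hedged sanity check only: `torch`, `mamba_ssm`, and `causal_conv1d` are the standard import names of the installed packages, and nothing here is specific to PromptHash.

```python
# Hedged sanity check for the installation steps above.
import torch

print(torch.__version__)          # the README targets a 2.7.x CUDA 12.8 build
print(torch.cuda.is_available())  # True if the cu128 wheels match your driver

try:
    import mamba_ssm       # noqa: F401
    import causal_conv1d   # noqa: F401
    print("mamba_ssm and causal-conv1d import cleanly")
except ImportError as err:
    print(f"CUDA extension not installed correctly: {err}")
```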

dataset/README.md

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+### The generation of each mat file
+
+You can use our pre-processed datasets, which makes it easier to get started.
+
+To generate the required `.mat` files:
+1. Download the cleaned datasets from `pan.baidu.com`:
+
+   link: https://pan.baidu.com/s/1jCYEBhm-bpikAh_Bti139g
+   password: 9idm
+
+2. Move the downloaded `all_imgs.txt`, `all_tags.txt`, and `all_labels.txt` to `./dataset/XXXDatasetName/` as follows:
+```
+dataset
+├── coco
+│   ├── all_imgs.txt
+│   ├── all_tags.txt
+│   └── all_labels.txt
+├── flickr25k
+│   ├── all_imgs.txt
+│   ├── all_tags.txt
+│   └── all_labels.txt
+└── nuswide
+    ├── all_imgs.txt
+    ├── all_tags.txt
+    └── all_labels.txt
+```
+3. Set the variable `img_root_path` in the `make_XXXDatasetName.py` scripts to the absolute path of the directory containing all source images (the images are also available at the `pan.baidu.com` link above).
+4. Run the `make_XXXDatasetName.py` scripts to generate the corresponding `.mat` files, then use these `.mat` files to run the experiments.
+
+
+### (Optional) The meaning and format of each mat file
+
+#### caption.mat
+For each dataset, `caption.mat` holds the text-modality data. It is a mat file with the key `caption`.
+Its shape is, e.g., `(20015,)` for MIRFlickr25K.
+Each element is a `string` that describes one image, e.g., "cigarette tattoos smoke red dress sunglasses" for `im1.jpg` in the MIRFlickr25K dataset.
+
+Note that 20,015 instances of MIRFlickr25K with 1,386 frequent textual tags and 190,421 instances of NUSWIDE with 1,000 frequent textual tags are used for the experiments.
+
+For MS COCO, we obtain 122,218 data points by removing the pairs without any label, following DCHMT, and one of the five sentences is randomly selected to form each image-text pair.
+
+#### index.mat
+
+`index.mat` is a mat file with the key `index`. Its shape is `(20015,)` for MIRFlickr25K.
+Each element is a `string` giving an image path, e.g., "/path/flickr25k/im1.jpg".
+
+#### label.mat
+
+`label.mat` is a mat file with the key `label`. Its shape is `(20015, 24)` for MIRFlickr25K.
+Each element is a `numpy.ndarray`, e.g., `[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0.]`.
+
+For all datasets, the details are as follows:
+
+| Dataset | File name | Shape | One element |
+|:------------:|:------------------------:|:------------:|:--------------------------------------------:|
+| MIRFlickr25K | caption.mat | (20015,) | cigarette tattoos smoke red dress sunglasses |
+| MIRFlickr25K | index.mat | (20015,) | /path/im1.jpg |
+| MIRFlickr25K | label.mat | (20015, 24) | [0. 0. ... 0.] |
+| MS COCO | caption.mat | (122218,) | A woman cutting a large white sheet cake |
+| MS COCO | index.mat | (122218,) | /path/COCO_val2014_000000522418.jpg |
+| MS COCO | label.mat | (122218, 80) | [1. 0. ... 0.] |
+| NUSWIDE | caption.mat | (190421,) | portrait man flash sunglasses actor december |
+| NUSWIDE | index.mat | (190421,) | /path/0001_2124494179.jpg |
+| NUSWIDE | label.mat | (190421, 21) | [0. 0. ... 0.] |
+
+You should generate these mat files in the above format for the experiments.
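As a hedged companion to the table above, the following sketch loads the three `.mat` files with `scipy.io.loadmat` and checks their keys and shapes. The paths are placeholders, and depending on how the files were saved, the string arrays may load as `(1, N)` cell arrays rather than `(N,)`.

```python
# Sketch: inspect generated .mat files against the documented keys and shapes.
import scipy.io as scio

root = "dataset/flickr25k"  # placeholder path; point at your dataset directory
captions = scio.loadmat(f"{root}/caption.mat")["caption"]
index = scio.loadmat(f"{root}/index.mat")["index"]
labels = scio.loadmat(f"{root}/label.mat")["label"]

print(captions.shape)  # documented as (20015,) for MIRFlickr25K; may load as (1, 20015)
print(index.shape)     # one image path per entry
print(labels.shape)    # (20015, 24): one multi-hot label vector per image
```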

hash_model.py

Lines changed: 2 additions & 22 deletions
@@ -71,7 +71,7 @@ def replace_underscore(self, name_list):
         self._prompt_cache[cache_key] = prompt_ids
         return prompt_ids
 
-    @torch.cuda.amp.autocast()
+    @torch.amp.autocast("cuda")
     def forward(self, classnames):
         batch_size = len(classnames)
 
@@ -109,30 +109,9 @@ def __init__(self, num_layers=1, hidden_size=1024, nhead=4):
         self.inproj = nn.Linear(self.sigal_d, self.sigal_d)
         self.outproj = nn.Linear(self.sigal_d, self.sigal_d)
         self.mamba = MambaLayer(dim=self.sigal_d, d_state=16, d_conv=4, expand=2)
-        # self.grn1 = nn.LayerNorm(self.sigal_d)
-        # self.grn2 = nn.LayerNorm(self.d_model)
         self.grn1 = GRN(dim=self.sigal_d)
         self.grn2 = GRN(dim=self.d_model)
 
-    def weight_init(self):
-        self.inproj.apply(self.kaiming_init)
-        self.outproj.apply(self.kaiming_init)
-
-    def kaiming_init(self, m):
-        classname = m.__class__.__name__
-        if classname.find('Conv') != -1:
-            init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
-            if m.bias is not None:
-                init.constant_(m.bias, 0.0)
-        elif classname.find('Linear') != -1:
-            init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
-            if m.bias is not None:
-                init.constant_(m.bias, 0.0)
-        elif classname.find('Norm') != -1:
-            init.normal_(m.weight.data, 1.0, 0.02)
-            if m.bias is not None:
-                init.constant_(m.bias.data, 0.0)
-
     def forward(self, img_cls, txt_eos):
         short_img_cls = self.inproj(img_cls)
         short_txt_eos = self.inproj(txt_eos)
@@ -355,6 +334,7 @@ def __init__(self, class_name_list, layers_to_unfreeze, args=None):
 
         # Unfreeze specific layers
         for name, param in self.clip.named_parameters():
+            # print(name)
             if name in layers_to_unfreeze:
                 param.requires_grad = True
 
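The decorator swap in the first hunk follows PyTorch's move from the deprecated `torch.cuda.amp.autocast()` to the device-agnostic `torch.amp.autocast("cuda")`. A minimal, self-contained illustration of the new form is below; the module and its forward body are stand-ins, not the repository's code.

```python
import torch
import torch.nn as nn

class TinyHead(nn.Module):  # hypothetical stand-in module
    def __init__(self, dim: int = 512):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    @torch.amp.autocast("cuda")  # new-style decorator, as used in this commit
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The body runs under CUDA mixed precision when a GPU is present.
        return self.proj(x)
```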

load_data.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from __future__ import unicode_literals
 from __future__ import print_function
 
-from model.open_clip.simple_tokenizer import SimpleTokenizer
+from model.open_clip.tokenizer import SimpleTokenizer
 import os
 import numpy as np
 import scipy.io as scio
@@ -153,7 +153,7 @@ def generate_dataset(captionFile: str,
             raise RuntimeError("text file is not support, we only read the keys of [caption, tags, YAll].")
         captions = captions[0] if captions.shape[0] == 1 else captions
     elif captionFile.endswith("txt"):
-        with open(captionFile, "r") as f:
+        with open(captionFile, 'r', encoding="utf-8") as f:
            captions = f.readlines()
        captions = np.asarray([[item.strip()] for item in captions])
    else:
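The second hunk pins the caption-file read to UTF-8, which removes the dependence on the platform's default encoding (on Windows this would otherwise often be cp1252 and fail on non-ASCII tags). The same pattern in isolation, with a hypothetical helper name:

```python
import numpy as np

def read_caption_lines(path: str) -> np.ndarray:
    # Explicit UTF-8 mirrors the change above and behaves the same on every platform.
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    # One single-element list per line, matching the np.asarray layout used in load_data.py.
    return np.asarray([[line.strip()] for line in lines])
```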

model/modules.py

Lines changed: 1 addition & 4 deletions
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from mamba_ssm import Mamba
+
 class GRN(nn.Module):
     """ GRN (Global Response Normalization) layer
     """
@@ -36,12 +37,10 @@ def __init__(self, dim, d_state=16, d_conv=4, expand=2):
         self.dim = dim
         self.nin = nn.Linear(dim, dim)
         self.nin2 = nn.Linear(dim, dim)
-        # self.norm2 = nn.LayerNorm(dim)
         self.norm2 = GRN1(dim=dim)
         self.act2 = nn.SiLU()
         self.act3 = nn.SiLU()
 
-        # self.norm = nn.LayerNorm(dim)
         self.norm = GRN1(dim=dim)
         self.act = nn.SiLU()
         self.mamba = Mamba(
@@ -88,8 +87,6 @@ def forward(self, x):
 
 if __name__ == '__main__':
     mamba = MambaLayer(dim=512).cuda()
-
-
     input = torch.rand(32, 196, 512).cuda()
     output = mamba(input)
     print(input.size())
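The commented-out `nn.LayerNorm` lines removed here had already been superseded by `GRN1` modules; the GRN definition itself lies outside this hunk. For orientation, here is a hedged sketch of a Global Response Normalization layer in the ConvNeXt V2 style, applied to `(batch, tokens, dim)` inputs; the repository's `GRN`/`GRN1` classes may differ in detail.

```python
import torch
import torch.nn as nn

class GRNSketch(nn.Module):
    """Global Response Normalization over the channel dimension (ConvNeXt V2 style).

    Illustrative only; not the repository's GRN/GRN1 implementation.
    """
    def __init__(self, dim: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, dim)
        gx = torch.norm(x, p=2, dim=1, keepdim=True)        # aggregate over tokens
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)    # divisive normalization across channels
        return self.gamma * (x * nx) + self.beta + x        # learnable calibration + residual
```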

model/open_clip/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+from .version import __version__
+
 from .coca_model import CoCa
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
