Commit ffc6b9c

Refactor layers for CLIP text encoder of SD model (#30)
* Refactor layers for CLIP text encoder of SD model
* Update comments for return values of model loader.
* Remove shared gate feedforward, which was due to a wrong implementation of quick GELU.
* Remove SharedGatedFeedForward
* Reformat loader.py
1 parent 475607a commit ffc6b9c
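
The commit message points at why SharedGatedFeedForward could go: quick GELU is a plain pointwise activation, x * sigmoid(1.702 * x), so once it is implemented correctly the feed-forward block needs no gate branch and the standard sequential form (up-projection, activation, down-projection) is enough. A minimal sketch of the distinction in plain PyTorch (these class names are illustrative, not from the repo):

import torch
from torch import nn


def quick_gelu(x: torch.Tensor) -> torch.Tensor:
  # Quick GELU: a pointwise GELU approximation, no extra weights needed.
  return x * torch.sigmoid(1.702 * x)


class SequentialFeedForward(nn.Module):
  """The shape kept after the refactor: linear -> quick GELU -> linear."""

  def __init__(self, dim: int, hidden_dim: int):
    super().__init__()
    self.up = nn.Linear(dim, hidden_dim)
    self.down = nn.Linear(hidden_dim, dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.down(quick_gelu(self.up(x)))


class GatedFeedForward(nn.Module):
  """The shape that was removed: a second projection gates the hidden state.
  Gating is only needed for GLU-style activations, not for quick GELU."""

  def __init__(self, dim: int, hidden_dim: int):
    super().__init__()
    self.up = nn.Linear(dim, hidden_dim)
    self.gate = nn.Linear(dim, hidden_dim)
    self.down = nn.Linear(hidden_dim, dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.down(quick_gelu(self.gate(x)) * self.up(x))

This matches the switch to cfg.FeedForwardType.SEQUENTIAL with cfg.ActivationType.GELU_QUICK in get_model_config() in the diff below.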

7 files changed (+178, -358 lines)

ai_edge_torch/generative/examples/stable_diffusion/clip.py

Lines changed: 83 additions & 49 deletions
@@ -15,65 +15,99 @@

 import torch
 from torch import nn
-from torch._prims_common import mask_tensor
-from torch._prims_common.wrappers import out_wrapper

-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attention_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="layers.{}.linear_1",
+    ff_down_proj="layers.{}.linear_2",
+    ff_gate_proj="layers.{}.linear_1",
+    attn_fused_qkv_proj="layers.{}.attention.in_proj",
+    attn_output_proj="layers.{}.attention.out_proj",
+    pre_attn_norm="layers.{}.layernorm_1",
+    pre_ff_norm="layers.{}.layernorm_2",
+    embedding="embedding.token_embedding",
+    embedding_position="embedding.position_value",
+    final_norm="layernorm",
+    lm_head=None,
+)


-class CLIPEmbedding(nn.Module):
-
-  def __init__(self, n_vocab: int, n_embd: int, n_token: int):
-    super().__init__()
-    self.token_embedding = nn.Embedding(n_vocab, n_embd)
-    self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-  def forward(self, tokens):
-    x = self.token_embedding(tokens)
-    x += self.position_value
-    return x
-
-
-class CLIPLayer(nn.Module):
+class CLIP(nn.Module):
+  """CLIP text encoder
+  For details, see https://arxiv.org/abs/2103.00020
+  """

-  def __init__(self, n_head: int, n_embd: int):
+  def __init__(self, config: cfg.ModelConfig):
     super().__init__()
-    self.layernorm_1 = nn.LayerNorm(n_embd)
-    self.attention = SelfAttention(n_head, n_embd)
-    self.layernorm_2 = nn.LayerNorm(n_embd)
-    self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-    self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-
-  def forward(self, x):
-    residue = x
-    x = self.layernorm_1(x)
-    x = self.attention(x, causal_mask=True)
-    x += residue
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((config.max_seq_len, config.embedding_dim))
+    )

-    residue = x
-    x = self.layernorm_2(x)
-    x = self.linear_1(x)
-    x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
-    x = self.linear_2(x)
-    x += residue
+    self.config = config
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

-    return x
-
-
-class CLIP(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.embedding = CLIPEmbedding(49408, 768, 77)
-    self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
-    self.layernorm = nn.LayerNorm(768)
+    self.mask_cache = attention_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32
+    )

   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
     tokens = tokens.type(torch.long)

-    state = self.embedding(tokens)
-    for layer in self.layers:
-      state = layer(state)
-    output = self.layernorm(state)
+    state = self.tok_embedding(tokens) + self.tok_embedding_position
+    for layer in self.transformer_blocks:
+      state = layer(state, mask=self.mask_cache)
+    output = self.final_norm(state)
     return output
+
+
+def get_model_config() -> cfg.ModelConfig:
+  max_seq_len = 77
+  vocab_size = 49408
+  num_layers = 12
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 768
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationType.GELU_QUICK,
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
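
The refactored module now assembles the encoder entirely from the shared generative layers: token and position embeddings, a stack of TransformerBlock instances driven by get_model_config(), a builder-constructed final layer norm, and a precomputed causal mask over the full 77-token sequence. A rough usage sketch of the new entry points (the dummy token batch below is illustrative, not from the repo):

import torch

from ai_edge_torch.generative.examples.stable_diffusion import clip

# Build the text encoder from its declarative config:
# 12 layers, 12 heads, embedding_dim=768, max_seq_len=77.
model = clip.CLIP(clip.get_model_config())

# A dummy batch of prompt token ids padded to the 77-token context.
tokens = torch.zeros((1, 77), dtype=torch.long)

# forward() runs under torch.inference_mode and returns per-token embeddings
# after the final layer norm, here of shape (1, 77, 768).
embeddings = model(tokens)
print(embeddings.shape)

Loading pretrained weights into this model is handled by the ModelLoader call shown in convert_to_tflite.py below.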

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

Lines changed: 7 additions & 5 deletions
@@ -19,11 +19,12 @@
 import torch

 import ai_edge_torch
-from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP
+import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
 from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.loader as loading_utils


 @torch.inference_mode
@@ -36,8 +37,9 @@ def convert_stable_diffusion_to_tflite(
     image_width: int = 512,
 ):

-  clip = CLIP()
-  clip.load_state_dict(torch.load(clip_ckpt_path))
+  clip_model = clip.CLIP(clip.get_model_config())
+  loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+  loader.load(clip_model, strict=False)

   encoder = Encoder()
   encoder.load_state_dict(torch.load(encoder_ckpt_path))
@@ -59,13 +61,13 @@ def convert_stable_diffusion_to_tflite(
   )

   input_latents = encoder(input_image, noise)
-  context_cond = clip(prompt_tokens)
+  context_cond = clip_model(prompt_tokens)
   context_uncond = torch.zeros_like(context_cond)
   context = torch.cat([context_cond, context_uncond], axis=0)
   time_embedding = util.get_time_embedding(timestamp)

   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export(
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
       '/tmp/stable_diffusion/clip.tflite'
   )
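
The converter no longer calls load_state_dict on a structurally matching module: it builds the model from the config and uses ModelLoader, with the TENSOR_NAMES patterns telling the loader which original checkpoint tensor feeds each role in the refactored model (strict=False presumably tolerating entries without a counterpart, such as lm_head=None). A small illustration of how the {} placeholder in those patterns expands per transformer block (not code from loader.py):

# Each TensorNames entry is a format string; the block index fills the placeholder.
pattern = "layers.{}.attention.in_proj"
checkpoint_keys = [pattern.format(i) for i in range(12)]
# -> ['layers.0.attention.in_proj', ..., 'layers.11.attention.in_proj']
print(checkpoint_keys[0], checkpoint_keys[-1])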
