
Commit f57e99e

Migrate stable diffusion example to ai-torch-edge (#14)
Refactoring will follow to move the reused modules into the layers directory.
1 parent a0e1125 commit f57e99e

10 files changed: +1132 additions, −0 deletions

Lines changed: 14 additions & 0 deletions

A new 14-line file containing only the Apache 2.0 license header (likely the package `__init__.py`; the filename is not preserved in this view):

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
```
Lines changed: 106 additions & 0 deletions

The attention primitives used by the other files. The module path, `ai_edge_torch/generative/examples/stable_diffusion/attention.py`, is confirmed by the imports in the files below:

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import math

import torch
from torch import _decomp
from torch import nn
from torch._prims_common import mask_tensor
from torch._prims_common.wrappers import out_wrapper
from torch.nn import functional as F


def triu(a):
    # Decomposition of aten.triu for boolean masks: keep only the entries
    # strictly above the main diagonal (equivalent to a.triu(1)).
    h, w = a.shape[-2:]
    mask = (
        torch.arange(w, device=a.device).unsqueeze(-2)
        - torch.arange(h, device=a.device).unsqueeze(-1)
    ) >= 1
    mask = torch.broadcast_to(mask, a.shape)
    return torch.ops.aten.logical_and(a, mask).contiguous()


# _decomp.decomposition_table[torch.ops.aten.triu.default] = triu


class SelfAttention(nn.Module):

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # One fused projection produces Q, K and V in a single matmul.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # (batch, seq, heads, d_head) -> (batch, heads, seq, d_head)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        weight = q @ k.transpose(-1, -2)
        if causal_mask:
            # mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
            mask = triu(torch.ones_like(weight, dtype=torch.bool))
            weight.masked_fill_(mask, -torch.inf)
        weight /= math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)

        output = weight @ v
        output = output.transpose(1, 2)
        output = output.reshape(input_shape)
        output = self.out_proj(output)
        return output


class CrossAttention(nn.Module):

    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # Queries come from x (d_embed); keys and values from the context y (d_cross).
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        # -1 lets the keys/values carry a different sequence length than x.
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        q = self.q_proj(x)
        k = self.k_proj(y)
        v = self.v_proj(y)

        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        weight = q @ k.transpose(-1, -2)
        weight /= math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)

        output = weight @ v
        output = output.transpose(1, 2).contiguous()
        output = output.view(input_shape)
        output = self.out_proj(output)
        return output
```
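
For orientation, a quick shape check on these two modules. This is a minimal sketch, not part of the commit; the tensor sizes and head counts are arbitrary illustrative choices:

```python
import torch

# Illustrative sizes: batch 2, sequence 16, 8 heads, 64-dim embeddings,
# and a 32-dim cross-attention context (e.g. text embeddings).
x = torch.randn(2, 16, 64)
ctx = torch.randn(2, 77, 32)

self_attn = SelfAttention(n_heads=8, d_embed=64)
cross_attn = CrossAttention(n_heads=8, d_embed=64, d_cross=32)

print(self_attn(x, causal_mask=True).shape)  # torch.Size([2, 16, 64])
print(cross_attn(x, ctx).shape)              # torch.Size([2, 16, 64])
```

Both modules preserve the query's (batch, sequence, d_embed) shape; `CrossAttention` only requires that `y` share the batch size and have feature dimension `d_cross`.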
Lines changed: 78 additions & 0 deletions

The CLIP text encoder (judging by the classes it defines, likely the example's `clip.py`):

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch
from torch import nn
from torch._prims_common import mask_tensor
from torch._prims_common.wrappers import out_wrapper

from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA


class CLIPEmbedding(nn.Module):

    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        # Learned positional embeddings, one per token position.
        self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))

    def forward(self, tokens):
        x = self.token_embedding(tokens)
        x += self.position_value
        return x


class CLIPLayer(nn.Module):

    def __init__(self, n_head: int, n_embd: int):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embd)
        self.attention = SelfAttention(n_head, n_embd)
        self.layernorm_2 = nn.LayerNorm(n_embd)
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        # Pre-norm transformer block: attention, then a 4x-wide MLP,
        # each wrapped in a residual connection.
        residue = x
        x = self.layernorm_1(x)
        x = self.attention(x, causal_mask=True)
        x += residue

        residue = x
        x = self.layernorm_2(x)
        x = self.linear_1(x)
        x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
        x = self.linear_2(x)
        x += residue

        return x


class CLIP(nn.Module):

    def __init__(self):
        super().__init__()
        # Stable Diffusion's text encoder: 49408-token vocabulary, 768-dim
        # embeddings, 77-token context, 12 transformer layers.
        self.embedding = CLIPEmbedding(49408, 768, 77)
        self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        tokens = tokens.type(torch.long)

        state = self.embedding(tokens)
        for layer in self.layers:
            state = layer(state)
        output = self.layernorm(state)
        return output
```
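
As a sanity check, the encoder maps a (batch, 77) tensor of token ids to (batch, 77, 768) embeddings. A minimal sketch with random, untrained weights; in the actual example the weights would be loaded from a Stable Diffusion checkpoint:

```python
import torch

clip = CLIP()
tokens = torch.zeros((1, 77), dtype=torch.long)  # one padded 77-token prompt
embeddings = clip(tokens)
print(embeddings.shape)  # torch.Size([1, 77, 768])
```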
Lines changed: 111 additions & 0 deletions

The VAE image decoder (likely the example's `decoder.py`):

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from torch import nn
from torch.nn import functional as F

from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA


class AttentionBlock(nn.Module):

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        residue = x
        x = self.groupnorm(x)

        # Flatten the spatial dimensions into a sequence so self-attention
        # can run over all pixels, then restore the (N, C, H, W) layout.
        n, c, h, w = x.shape
        x = x.view((n, c, h * w))
        x = x.transpose(-1, -2)
        x = self.attention(x)
        x = x.transpose(-1, -2)
        x = x.view((n, c, h, w))

        x += residue
        return x


class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            # Project the residual with a 1x1 conv when channel counts differ.
            self.residual_layer = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, padding=0
            )

    def forward(self, x):
        residue = x

        x = self.groupnorm_1(x)
        x = F.silu(x)
        x = self.conv_1(x)

        x = self.groupnorm_2(x)
        x = F.silu(x)
        x = self.conv_2(x)

        return x + self.residual_layer(residue)


class Decoder(nn.Sequential):

    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 4, kernel_size=1, padding=0),
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            AttentionBlock(512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 256),
            ResidualBlock(256, 256),
            ResidualBlock(256, 256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ResidualBlock(256, 128),
            ResidualBlock(128, 128),
            ResidualBlock(128, 128),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Undo the latent scaling factor used by Stable Diffusion's VAE.
        x = x / 0.18215
        for module in self:
            x = module(x)
        return x
```
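
The decoder maps Stable Diffusion latents at 1/8 spatial resolution back to an RGB image; the three `nn.Upsample(scale_factor=2)` stages account for the overall 8x upscale. A minimal sketch with random weights and an illustrative latent size:

```python
import torch

decoder = Decoder()
latents = torch.randn(1, 4, 64, 64)  # latents for a 512x512 output image
with torch.no_grad():  # inference only; avoids keeping activations for backprop
    image = decoder(latents)
print(image.shape)  # torch.Size([1, 3, 512, 512])
```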
