Skip to content

Commit 35cf9f8

Browse files
authored
Merge pull request #4 from zeroxoxo/zeroxoxo-patch-1
Major update #1
2 parents 49174ff + 2f015c1 commit 35cf9f8

File tree

4 files changed

+348
-187
lines changed

4 files changed

+348
-187
lines changed

FastStyleTransferNode.py

Lines changed: 111 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,22 @@
11
"""
2-
This node is a simple conversion of this repository into ComfyUI ecosystem:
2+
These nodes are a simple conversion of these repositories into ComfyUI ecosystem:
33
https://github.com/rrmina/fast-neural-style-pytorch.git
4+
https://github.com/gordicaleksa/pytorch-neural-style-transfer.git
45
5-
Some of the code is written by ChatGPT4-o
66
"""
77

88
import torch
99
import torch.nn as nn
10-
from torchvision import models
1110
import time
1211
import os
1312
import folder_paths
1413
import subprocess as sp
1514
import sys
1615

17-
18-
19-
20-
class VGG16(nn.Module):
21-
def __init__(self, vgg_path=os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/vgg/vgg16-00b39a1b.pth"), train=False):
22-
super(VGG16, self).__init__()
23-
# Load VGG Skeleton, Pretrained Weights
24-
vgg16_features = models.vgg16(pretrained=False)
25-
vgg16_features.load_state_dict(torch.load(vgg_path), strict=False)
26-
self.features = vgg16_features.features
27-
28-
# Turn-off Gradient History (on for testing train)
29-
for param in self.features.parameters():
30-
param.requires_grad = train
31-
32-
def forward(self, x):
33-
layers = {'3': 'relu1_2', '8': 'relu2_2', '15': 'relu3_3', '22': 'relu4_3'}
34-
features = {}
35-
for name, layer in self.features._modules.items():
36-
x = layer(x)
37-
if name in layers:
38-
features[layers[name]] = x
39-
if (name=='22'):
40-
break
41-
42-
return features
43-
44-
45-
class ConvLayer(nn.Module):
16+
# ML classes
17+
class ConvolutionalLayer(nn.Module):
4618
def __init__(self, in_channels, out_channels, kernel_size, stride, norm="instance"):
47-
super(ConvLayer, self).__init__()
19+
super(ConvolutionalLayer, self).__init__()
4820
# Padding Layers
4921
self.padding_size = kernel_size // 2
5022
self.reflection_pad = nn.ReflectionPad2d(self.padding_size)
@@ -77,9 +49,9 @@ class ResidualLayer(nn.Module):
7749
"""
7850
def __init__(self, channels=128, kernel_size=3):
7951
super(ResidualLayer, self).__init__()
80-
self.conv1 = ConvLayer(channels, channels, kernel_size, stride=1)
52+
self.conv1 = ConvolutionalLayer(channels, channels, kernel_size, stride=1)
8153
self.relu = nn.ReLU()
82-
self.conv2 = ConvLayer(channels, channels, kernel_size, stride=1)
54+
self.conv2 = ConvolutionalLayer(channels, channels, kernel_size, stride=1)
8355

8456
def forward(self, x):
8557
identity = x # preserve residual
@@ -89,9 +61,9 @@ def forward(self, x):
8961
return out
9062

9163

92-
class DeconvLayer(nn.Module):
64+
class DeconvolutionalLayer(nn.Module):
9365
def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, norm="instance"):
94-
super(DeconvLayer, self).__init__()
66+
super(DeconvolutionalLayer, self).__init__()
9567

9668
# Transposed Convolution
9769
padding_size = kernel_size // 2
@@ -121,11 +93,11 @@ class TransformerNetwork(nn.Module):
12193
def __init__(self):
12294
super(TransformerNetwork, self).__init__()
12395
self.ConvBlock = nn.Sequential(
124-
ConvLayer(3, 32, 9, 1),
96+
ConvolutionalLayer(3, 32, 9, 1),
12597
nn.ReLU(),
126-
ConvLayer(32, 64, 3, 2),
98+
ConvolutionalLayer(32, 64, 3, 2),
12799
nn.ReLU(),
128-
ConvLayer(64, 128, 3, 2),
100+
ConvolutionalLayer(64, 128, 3, 2),
129101
nn.ReLU()
130102
)
131103
self.ResidualBlock = nn.Sequential(
@@ -136,11 +108,11 @@ def __init__(self):
136108
ResidualLayer(128, 3)
137109
)
138110
self.DeconvBlock = nn.Sequential(
139-
DeconvLayer(128, 64, 3, 2, 1),
111+
DeconvolutionalLayer(128, 64, 3, 2, 1),
140112
nn.ReLU(),
141-
DeconvLayer(64, 32, 3, 2, 1),
113+
DeconvolutionalLayer(64, 32, 3, 2, 1),
142114
nn.ReLU(),
143-
ConvLayer(32, 3, 9, 1, norm="None")
115+
ConvolutionalLayer(32, 3, 9, 1, norm="None")
144116
)
145117

146118
def forward(self, x):
@@ -150,25 +122,27 @@ def forward(self, x):
150122
return out
151123

152124

125+
# Node classes
153126
class TrainFastStyleTransfer:
154127
    def __init__(self):
        # Stateless node: all configuration arrives through INPUT_TYPES/train().
        pass
156129

157130
@classmethod
158131
def INPUT_TYPES(s):
159-
input_dir = folder_paths.get_input_directory()
160-
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
161132
return {
162133
"required": {
163-
"style_img": (sorted(files), {"image_upload": True}),
134+
"style_img": ("IMAGE",),
164135
"seed": ("INT", {"default": 30, "min": 0, "max": 999999, "step": 1,}),
165136
"content_weight": ("INT", {"default": 14, "min": 1, "max": 128, "step": 1,}),
166137
"style_weight": ("INT", {"default": 50, "min": 1, "max": 128, "step": 1,}),
167-
"batch_size": ("INT", {"default": 4, "min": 1, "max": 128, "step": 1,}),
168-
"train_img_size": ("INT", {"default": 256, "min": 256, "max": 2048, "step": 1,}),
169-
"learning_rate": ("FLOAT", {"default": 0.001, "min": 0.0001, "max": 0.1, "step": 0.0001}),
138+
"tv_weight": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step": 0.0000001}),
139+
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32, "step": 1,}),
140+
"train_img_size": ("INT", {"default": 256, "min": 128, "max": 2048, "step": 1,}),
141+
"learning_rate": ("FLOAT", {"default": 0.001, "min": 0.0001, "max": 100.0, "step": 0.0001}),
170142
"num_epochs": ("INT", {"default": 1, "min": 1, "max": 20, "step": 1,}),
171-
"save_model_every": ("INT", {"default": 500, "min": 100, "max": 10000, "step": 1,}),
143+
"save_model_every": ("INT", {"default": 500, "min": 10, "max": 10000, "step": 10,}),
144+
"from_pretrained": ("INT", {"default": 0, "min": 0, "max": 1, "step": 1,}),
145+
"model": ([file for file in os.listdir(os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/models/")) if file.endswith('.pth')], ),
172146
},
173147
}
174148

@@ -179,14 +153,19 @@ def INPUT_TYPES(s):
179153

180154
CATEGORY = "Style Transfer"
181155

156+
def encode_tensor(self, tensor):
157+
tensor = tensor.permute(0, 3, 1, 2).contiguous() # Convert to [batch_size, channels, height, width]
158+
return tensor[:, [2, 1, 0], :, :] * 255
182159

183-
def train(self, style_img, seed, batch_size, train_img_size, learning_rate, num_epochs, content_weight, style_weight, save_model_every):
160+
161+
def train(self, style_img, seed, batch_size, train_img_size, learning_rate, num_epochs, content_weight, style_weight, tv_weight, save_model_every, from_pretrained, model):
162+
temp_save_style_img = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/temp/") + "temp_save_content_img.pt"
184163
save_model_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/models/")
185164
dataset_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/dataset/")
186165
vgg_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/vgg/vgg16-00b39a1b.pth")
187166
save_image_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/output/")
188-
style_image_path = folder_paths.get_annotated_filepath(style_img)
189167
train_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/train.py")
168+
190169

191170

192171
command = [
@@ -195,24 +174,31 @@ def train(self, style_img, seed, batch_size, train_img_size, learning_rate, num_
195174
'--dataset_path', dataset_path,
196175
'--vgg_path', vgg_path,
197176
'--num_epochs', str(num_epochs),
198-
'--style_image_path', style_image_path,
177+
'--temp_save_style_img', temp_save_style_img,
199178
'--batch_size', str(batch_size),
200179
'--content_weight', str(content_weight),
201180
'--style_weight', str(style_weight),
181+
'--tv_weight', str(tv_weight),
202182
'--adam_lr', str(learning_rate),
203183
'--save_model_path', save_model_path,
204184
'--save_image_path', save_image_path,
205185
'--save_model_every', str(save_model_every),
206-
'--seed', str(seed)
186+
'--seed', str(seed),
187+
'--pretrained_model'
207188
]
208189

190+
if from_pretrained:
191+
command.append(model)
192+
else:
193+
command.append('none')
194+
195+
196+
torch.save(self.encode_tensor(style_img), temp_save_style_img)
197+
209198
sp.run(command)
210199
return ()
211200

212201

213-
214-
215-
216202
class FastStyleTransfer:
217203
def __init__(self):
218204
pass
@@ -262,15 +248,81 @@ def styleTransfer(self, content_img, model):
262248
return (image,)
263249

264250

251+
class NeuralStyleTransfer:
    """ComfyUI node for classic optimization-based neural style transfer.

    The optimization itself runs in a subprocess (neural_style_transfer.py);
    tensors are exchanged with it through temporary .pt files on disk.

    Fix over the original: temp-file cleanup now happens in a ``finally``
    block, so the .pt files no longer leak when the subprocess fails or the
    result file cannot be loaded. Paths are also built once with
    ``os.path.join`` instead of duplicated string concatenation.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's inputs for the ComfyUI node editor."""
        return {
            "required": {
                "content_img": ("IMAGE",),
                "style_img": ("IMAGE",),
                "content_weight": ("FLOAT", {"default": 1e5, "min": 1e3, "max": 1e6, "step": 1e3}),
                "style_weight": ("FLOAT", {"default": 3e4, "min": 1e1, "max": 1e5, "step": 1e1}),
                "tv_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1e1, "step": 0.1}),
                "num_steps": ("INT", {"default": 100, "min": 10, "max": 10000, "step": 10}),
                "learning_rate": ("FLOAT", {"default": 1.0, "min": 1e-4, "max": 1e3, "step": 0.1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "neural_style_transfer"
    CATEGORY = "Style Transfer"

    def encode_tensor(self, tensor):
        """Convert ComfyUI's [batch, H, W, C] 0..1 image to [batch, C, H, W] 0..255."""
        tensor = tensor.permute(0, 3, 1, 2).contiguous()
        return tensor * 255

    def decode_tensor(self, tensor):
        """Inverse of encode_tensor: [batch, C, H, W] 0..255 back to [batch, H, W, C] 0..1."""
        tensor = tensor.permute(0, 2, 3, 1).contiguous()
        return tensor / 255

    def neural_style_transfer(self, content_img, style_img, content_weight, style_weight, tv_weight, num_steps, learning_rate):
        """Run the style-transfer subprocess and return the stylized image.

        Saves the content and style tensors to temp .pt files, invokes
        neural_style_transfer.py with the weights/steps as CLI arguments,
        then loads and decodes the result. Temp files are removed even if
        the subprocess or the load fails.
        """
        base = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer")
        neural_style_transfer_path = os.path.join(base, "neural_style_transfer.py")
        temp_dir = os.path.join(base, "temp/")

        temp_save_content_img = os.path.join(temp_dir, "temp_save_content_img.pt")
        temp_save_style_img = os.path.join(temp_dir, "temp_save_style_img.pt")
        temp_load_final_img = os.path.join(temp_dir, "temp_load_final_img.pt")

        try:
            torch.save(self.encode_tensor(content_img), temp_save_content_img)
            torch.save(self.encode_tensor(style_img), temp_save_style_img)

            command = [
                sys.executable, neural_style_transfer_path,
                '--content_weight', str(content_weight),
                '--style_weight', str(style_weight),
                '--tv_weight', str(tv_weight),
                '--temp_save_style_img', temp_save_style_img,
                '--temp_save_content_img', temp_save_content_img,
                '--temp_load_final_img', temp_load_final_img,
                '--num_steps', str(num_steps),
                '--learning_rate', str(learning_rate)
            ]

            # List form (shell=False) keeps paths with spaces safe from the shell.
            sp.run(command)

            image = self.decode_tensor(torch.load(temp_load_final_img))
        finally:
            # Remove temp tensors even when the subprocess or load fails,
            # so stale files never accumulate in temp/.
            for path in (temp_save_style_img, temp_save_content_img, temp_load_final_img):
                if os.path.exists(path):
                    os.remove(path)
        return (image,)
314+
265315
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "FastStyleTransfer": FastStyleTransfer,
    "TrainFastStyleTransfer": TrainFastStyleTransfer,
    "NeuralStyleTransfer": NeuralStyleTransfer,
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "FastStyleTransfer": "Fast Style Transfer",
    "TrainFastStyleTransfer": "Train Fast Style Transfer",
    "NeuralStyleTransfer": "Neural Style Transfer",
}

0 commit comments

Comments
 (0)