1
1
"""
These nodes are a simple conversion of these repositories into the ComfyUI ecosystem:

https://github.com/rrmina/fast-neural-style-pytorch.git
https://github.com/gordicaleksa/pytorch-neural-style-transfer.git

Some of the code was written by ChatGPT-4o.
"""
7
7
8
8
import torch
9
9
import torch .nn as nn
10
- from torchvision import models
11
10
import time
12
11
import os
13
12
import folder_paths
14
13
import subprocess as sp
15
14
import sys
16
15
17
-
18
-
19
-
20
- class VGG16 (nn .Module ):
21
- def __init__ (self , vgg_path = os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/vgg/vgg16-00b39a1b.pth" ), train = False ):
22
- super (VGG16 , self ).__init__ ()
23
- # Load VGG Skeleton, Pretrained Weights
24
- vgg16_features = models .vgg16 (pretrained = False )
25
- vgg16_features .load_state_dict (torch .load (vgg_path ), strict = False )
26
- self .features = vgg16_features .features
27
-
28
- # Turn-off Gradient History (on for testing train)
29
- for param in self .features .parameters ():
30
- param .requires_grad = train
31
-
32
- def forward (self , x ):
33
- layers = {'3' : 'relu1_2' , '8' : 'relu2_2' , '15' : 'relu3_3' , '22' : 'relu4_3' }
34
- features = {}
35
- for name , layer in self .features ._modules .items ():
36
- x = layer (x )
37
- if name in layers :
38
- features [layers [name ]] = x
39
- if (name == '22' ):
40
- break
41
-
42
- return features
43
-
44
-
45
- class ConvLayer (nn .Module ):
16
+ # ML classes
17
+ class ConvolutionalLayer (nn .Module ):
46
18
def __init__ (self , in_channels , out_channels , kernel_size , stride , norm = "instance" ):
47
- super (ConvLayer , self ).__init__ ()
19
+ super (ConvolutionalLayer , self ).__init__ ()
48
20
# Padding Layers
49
21
self .padding_size = kernel_size // 2
50
22
self .reflection_pad = nn .ReflectionPad2d (self .padding_size )
@@ -77,9 +49,9 @@ class ResidualLayer(nn.Module):
77
49
"""
78
50
def __init__ (self , channels = 128 , kernel_size = 3 ):
79
51
super (ResidualLayer , self ).__init__ ()
80
- self .conv1 = ConvLayer (channels , channels , kernel_size , stride = 1 )
52
+ self .conv1 = ConvolutionalLayer (channels , channels , kernel_size , stride = 1 )
81
53
self .relu = nn .ReLU ()
82
- self .conv2 = ConvLayer (channels , channels , kernel_size , stride = 1 )
54
+ self .conv2 = ConvolutionalLayer (channels , channels , kernel_size , stride = 1 )
83
55
84
56
def forward (self , x ):
85
57
identity = x # preserve residual
@@ -89,9 +61,9 @@ def forward(self, x):
89
61
return out
90
62
91
63
92
- class DeconvLayer (nn .Module ):
64
+ class DeconvolutionalLayer (nn .Module ):
93
65
def __init__ (self , in_channels , out_channels , kernel_size , stride , output_padding , norm = "instance" ):
94
- super (DeconvLayer , self ).__init__ ()
66
+ super (DeconvolutionalLayer , self ).__init__ ()
95
67
96
68
# Transposed Convolution
97
69
padding_size = kernel_size // 2
@@ -121,11 +93,11 @@ class TransformerNetwork(nn.Module):
121
93
def __init__ (self ):
122
94
super (TransformerNetwork , self ).__init__ ()
123
95
self .ConvBlock = nn .Sequential (
124
- ConvLayer (3 , 32 , 9 , 1 ),
96
+ ConvolutionalLayer (3 , 32 , 9 , 1 ),
125
97
nn .ReLU (),
126
- ConvLayer (32 , 64 , 3 , 2 ),
98
+ ConvolutionalLayer (32 , 64 , 3 , 2 ),
127
99
nn .ReLU (),
128
- ConvLayer (64 , 128 , 3 , 2 ),
100
+ ConvolutionalLayer (64 , 128 , 3 , 2 ),
129
101
nn .ReLU ()
130
102
)
131
103
self .ResidualBlock = nn .Sequential (
@@ -136,11 +108,11 @@ def __init__(self):
136
108
ResidualLayer (128 , 3 )
137
109
)
138
110
self .DeconvBlock = nn .Sequential (
139
- DeconvLayer (128 , 64 , 3 , 2 , 1 ),
111
+ DeconvolutionalLayer (128 , 64 , 3 , 2 , 1 ),
140
112
nn .ReLU (),
141
- DeconvLayer (64 , 32 , 3 , 2 , 1 ),
113
+ DeconvolutionalLayer (64 , 32 , 3 , 2 , 1 ),
142
114
nn .ReLU (),
143
- ConvLayer (32 , 3 , 9 , 1 , norm = "None" )
115
+ ConvolutionalLayer (32 , 3 , 9 , 1 , norm = "None" )
144
116
)
145
117
146
118
def forward (self , x ):
@@ -150,25 +122,27 @@ def forward(self, x):
150
122
return out
151
123
152
124
125
+ # Node classes
153
126
class TrainFastStyleTransfer :
154
127
def __init__ (self ):
155
128
pass
156
129
157
130
@classmethod
158
131
def INPUT_TYPES (s ):
159
- input_dir = folder_paths .get_input_directory ()
160
- files = [f for f in os .listdir (input_dir ) if os .path .isfile (os .path .join (input_dir , f ))]
161
132
return {
162
133
"required" : {
163
- "style_img" : (sorted ( files ), { "image_upload" : True } ),
134
+ "style_img" : ("IMAGE" , ),
164
135
"seed" : ("INT" , {"default" : 30 , "min" : 0 , "max" : 999999 , "step" : 1 ,}),
165
136
"content_weight" : ("INT" , {"default" : 14 , "min" : 1 , "max" : 128 , "step" : 1 ,}),
166
137
"style_weight" : ("INT" , {"default" : 50 , "min" : 1 , "max" : 128 , "step" : 1 ,}),
167
- "batch_size" : ("INT" , {"default" : 4 , "min" : 1 , "max" : 128 , "step" : 1 ,}),
168
- "train_img_size" : ("INT" , {"default" : 256 , "min" : 256 , "max" : 2048 , "step" : 1 ,}),
169
- "learning_rate" : ("FLOAT" , {"default" : 0.001 , "min" : 0.0001 , "max" : 0.1 , "step" : 0.0001 }),
138
+ "tv_weight" : ("FLOAT" , {"default" : 0.001 , "min" : 0.0 , "max" : 1.0 , "step" : 0.0000001 }),
139
+ "batch_size" : ("INT" , {"default" : 4 , "min" : 1 , "max" : 32 , "step" : 1 ,}),
140
+ "train_img_size" : ("INT" , {"default" : 256 , "min" : 128 , "max" : 2048 , "step" : 1 ,}),
141
+ "learning_rate" : ("FLOAT" , {"default" : 0.001 , "min" : 0.0001 , "max" : 100.0 , "step" : 0.0001 }),
170
142
"num_epochs" : ("INT" , {"default" : 1 , "min" : 1 , "max" : 20 , "step" : 1 ,}),
171
- "save_model_every" : ("INT" , {"default" : 500 , "min" : 100 , "max" : 10000 , "step" : 1 ,}),
143
+ "save_model_every" : ("INT" , {"default" : 500 , "min" : 10 , "max" : 10000 , "step" : 10 ,}),
144
+ "from_pretrained" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 1 , "step" : 1 ,}),
145
+ "model" : ([file for file in os .listdir (os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/models/" )) if file .endswith ('.pth' )], ),
172
146
},
173
147
}
174
148
@@ -179,14 +153,19 @@ def INPUT_TYPES(s):
179
153
180
154
CATEGORY = "Style Transfer"
181
155
156
    def encode_tensor(self, tensor):
        """Convert an image batch to the layout the external training script expects.

        Takes a channel-last batch (presumably ComfyUI's IMAGE layout
        [batch, height, width, channels] with 0..1 floats — TODO confirm)
        and returns a channel-first [batch, channels, height, width] tensor.
        """
        tensor = tensor.permute(0, 3, 1, 2).contiguous()  # Convert to [batch_size, channels, height, width]
        # Reorder channels 0,1,2 -> 2,1,0 (RGB <-> BGR swap) and scale by 255.
        return tensor[:, [2, 1, 0], :, :] * 255
182
159
183
- def train (self , style_img , seed , batch_size , train_img_size , learning_rate , num_epochs , content_weight , style_weight , save_model_every ):
160
+
161
+ def train (self , style_img , seed , batch_size , train_img_size , learning_rate , num_epochs , content_weight , style_weight , tv_weight , save_model_every , from_pretrained , model ):
162
+ temp_save_style_img = os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/temp/" ) + "temp_save_content_img.pt"
184
163
save_model_path = os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/models/" )
185
164
dataset_path = os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/dataset/" )
186
165
vgg_path = os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/vgg/vgg16-00b39a1b.pth" )
187
166
save_image_path = os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/output/" )
188
- style_image_path = folder_paths .get_annotated_filepath (style_img )
189
167
train_path = os .path .join (folder_paths .base_path , "custom_nodes/ComfyUI-Fast-Style-Transfer/train.py" )
168
+
190
169
191
170
192
171
command = [
@@ -195,24 +174,31 @@ def train(self, style_img, seed, batch_size, train_img_size, learning_rate, num_
195
174
'--dataset_path' , dataset_path ,
196
175
'--vgg_path' , vgg_path ,
197
176
'--num_epochs' , str (num_epochs ),
198
- '--style_image_path ' , style_image_path ,
177
+ '--temp_save_style_img ' , temp_save_style_img ,
199
178
'--batch_size' , str (batch_size ),
200
179
'--content_weight' , str (content_weight ),
201
180
'--style_weight' , str (style_weight ),
181
+ '--tv_weight' , str (tv_weight ),
202
182
'--adam_lr' , str (learning_rate ),
203
183
'--save_model_path' , save_model_path ,
204
184
'--save_image_path' , save_image_path ,
205
185
'--save_model_every' , str (save_model_every ),
206
- '--seed' , str (seed )
186
+ '--seed' , str (seed ),
187
+ '--pretrained_model'
207
188
]
208
189
190
+ if from_pretrained :
191
+ command .append (model )
192
+ else :
193
+ command .append ('none' )
194
+
195
+
196
+ torch .save (self .encode_tensor (style_img ), temp_save_style_img )
197
+
209
198
sp .run (command )
210
199
return ()
211
200
212
201
213
-
214
-
215
-
216
202
class FastStyleTransfer :
217
203
def __init__ (self ):
218
204
pass
@@ -262,15 +248,81 @@ def styleTransfer(self, content_img, model):
262
248
return (image ,)
263
249
264
250
251
class NeuralStyleTransfer:
    """ComfyUI node wrapping optimization-based neural style transfer.

    The node serializes the content and style images as ``.pt`` tensors in
    this extension's ``temp`` directory, runs the bundled
    ``neural_style_transfer.py`` script in a subprocess, then loads the
    resulting image tensor back and returns it to ComfyUI.
    """

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI input spec: two IMAGE sockets plus the optimization knobs.
        return {
            "required": {
                "content_img": ("IMAGE",),
                "style_img": ("IMAGE",),
                "content_weight": ("FLOAT", {"default": 1e5, "min": 1e3, "max": 1e6, "step": 1e3}),
                "style_weight": ("FLOAT", {"default": 3e4, "min": 1e1, "max": 1e5, "step": 1e1}),
                "tv_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1e1, "step": 0.1}),
                "num_steps": ("INT", {"default": 100, "min": 10, "max": 10000, "step": 10}),
                "learning_rate": ("FLOAT", {"default": 1.0, "min": 1e-4, "max": 1e3, "step": 0.1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "neural_style_transfer"
    CATEGORY = "Style Transfer"

    def encode_tensor(self, tensor):
        """Convert a channel-last [B, H, W, C] batch to [B, C, H, W] scaled by 255."""
        tensor = tensor.permute(0, 3, 1, 2).contiguous()  # channels-first for the script
        return tensor * 255

    def decode_tensor(self, tensor):
        """Inverse of :meth:`encode_tensor`: [B, C, H, W] back to [B, H, W, C] / 255."""
        tensor = tensor.permute(0, 2, 3, 1).contiguous()  # channels-last for ComfyUI
        return tensor / 255

    def neural_style_transfer(self, content_img, style_img, content_weight,
                              style_weight, tv_weight, num_steps, learning_rate):
        """Run the external style-transfer script and return the stylized image.

        All heavy lifting happens in a subprocess; tensors are exchanged
        through ``.pt`` files in this node's temp directory.
        """
        node_dir = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer")
        neural_style_transfer_path = os.path.join(node_dir, "neural_style_transfer.py")
        temp_dir = os.path.join(node_dir, "temp")

        # Fix: build file paths with os.path.join instead of
        # os.path.join(...) + "name" string concatenation.
        temp_save_content_img = os.path.join(temp_dir, "temp_save_content_img.pt")
        temp_save_style_img = os.path.join(temp_dir, "temp_save_style_img.pt")
        temp_load_final_img = os.path.join(temp_dir, "temp_load_final_img.pt")

        torch.save(self.encode_tensor(content_img), temp_save_content_img)
        torch.save(self.encode_tensor(style_img), temp_save_style_img)

        command = [
            sys.executable, neural_style_transfer_path,
            '--content_weight', str(content_weight),
            '--style_weight', str(style_weight),
            '--tv_weight', str(tv_weight),
            '--temp_save_style_img', temp_save_style_img,
            '--temp_save_content_img', temp_save_content_img,
            '--temp_load_final_img', temp_load_final_img,
            '--num_steps', str(num_steps),
            '--learning_rate', str(learning_rate),
        ]

        try:
            # NOTE(review): no check=True — a failed run surfaces as a missing
            # output file in torch.load below rather than a CalledProcessError.
            sp.run(command)
            image = self.decode_tensor(torch.load(temp_load_final_img))
        finally:
            # Fix: clean up the temp .pt files even when the subprocess or the
            # load fails, so failed runs no longer leak files on disk.
            for path in (temp_save_style_img, temp_save_content_img, temp_load_final_img):
                if os.path.exists(path):
                    os.remove(path)

        return (image,)
314
+
265
315
# A dictionary that contains all nodes you want to export with their names.
# ComfyUI reads this mapping at load time to register each node class.
# NOTE: names should be globally unique across all installed custom nodes.
NODE_CLASS_MAPPINGS = {
    "FastStyleTransfer": FastStyleTransfer,
    "TrainFastStyleTransfer": TrainFastStyleTransfer,
    "NeuralStyleTransfer": NeuralStyleTransfer,
}
271
322
272
323
# A dictionary that contains the friendly/humanly readable titles for the nodes.
# Keys must match NODE_CLASS_MAPPINGS; values are shown in the ComfyUI menu.
NODE_DISPLAY_NAME_MAPPINGS = {
    "FastStyleTransfer": "Fast Style Transfer",
    "TrainFastStyleTransfer": "Train Fast Style Transfer",
    "NeuralStyleTransfer": "Neural Style Transfer",
}
0 commit comments