Drafted here: https://github.com/johndpope/MegaPortrait-hack/tree/feat/sub-sampling
I don't see a massive speed-up. It's possible the crop size could be randomized - sometimes go full / half / quarter, etc.; a rough sketch of that idea is below.
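A minimal sketch of that randomized-scale idea (the helper name and the scale set are my assumptions, not something in the branch):

```python
import numpy as np
import torch

def random_scale_crop(predicted: torch.Tensor, target: torch.Tensor):
    """Hypothetical helper: each call trains on the full image, a half-size
    crop, or a quarter-size crop, chosen at random."""
    _, _, h, w = predicted.shape
    frac = np.random.choice([1.0, 0.5, 0.25])    # full / half / quarter
    crop_h, crop_w = max(1, int(h * frac)), max(1, int(w * frac))
    top = np.random.randint(0, h - crop_h + 1)   # one shared offset so both
    left = np.random.randint(0, w - crop_w + 1)  # tensors get the same patch
    return (predicted[..., top:top + crop_h, left:left + crop_w],
            target[..., top:top + crop_h, left:left + crop_w])
```

Below is the drafted PerceptualLoss with the fixed-size random crop wired in.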
```python
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from facenet_pytorch import InceptionResnetV1

from losses import MPGazeLoss  # project-local; adjust the import path as needed


class PerceptualLoss(nn.Module):
    def __init__(self, device, weights={'vgg19': 20.0, 'vggface': 5.0, 'gaze': 4.0}):
        super(PerceptualLoss, self).__init__()
        self.device = device
        self.weights = weights

        # VGG19 feature extractor (first 30 layers)
        vgg19 = models.vgg19(pretrained=True).features
        self.vgg19 = nn.Sequential(*[vgg19[i] for i in range(30)]).to(device).eval()
        self.vgg19_layers = [1, 6, 11, 20, 29]

        # VGGFace network
        self.vggface = InceptionResnetV1(pretrained='vggface2').to(device).eval()
        self.vggface_layers = [4, 5, 6, 7]

        # Freeze both feature extractors; gradients only need to flow
        # through the activations, not into the pretrained weights.
        for p in self.vgg19.parameters():
            p.requires_grad_(False)
        for p in self.vggface.parameters():
            p.requires_grad_(False)

        # Gaze loss
        self.gaze_loss = MPGazeLoss(device)

    def forward(self, predicted, target, sub_sample_size=(128, 128), use_fm_loss=False):
        # Normalize input images
        predicted = self.normalize_input(predicted)
        target = self.normalize_input(target)

        # Trick shot to reduce memory (Sec. 3.3 of the linked paper):
        # compute the loss on a random sub-sample instead of the full image.
        # https://arxiv.org/pdf/2404.09736#page=5.58
        predicted, target = self.sub_sample_tensor(predicted, target, sub_sample_size)

        # Compute VGG19 perceptual loss
        vgg19_loss = self.compute_vgg19_loss(predicted, target)

        # Compute VGGFace perceptual loss
        vggface_loss = self.compute_vggface_loss(predicted, target)

        # Gaze loss is currently disabled
        # gaze_loss = self.gaze_loss(predicted, target)

        # Compute total perceptual loss
        total_loss = (
            self.weights['vgg19'] * vgg19_loss +
            self.weights['vggface'] * vggface_loss
            # + self.weights['gaze'] * gaze_loss
        )

        if use_fm_loss:
            # Compute feature matching loss
            fm_loss = self.compute_feature_matching_loss(predicted, target)
            total_loss += fm_loss

        return total_loss

    def sub_sample_tensor(self, predicted, target, sub_sample_size):
        assert predicted.ndim == 4, "Input tensor should have 4 dimensions (batch_size, channels, height, width)"
        assert predicted.shape[-2] >= sub_sample_size[0] and predicted.shape[-1] >= sub_sample_size[1], \
            "Sub-sample size should not exceed the tensor dimensions"
        height, width = predicted.shape[-2:]
        # Randomly sample so we cover the whole image over training
        # (the +1 keeps a crop equal to the full size valid); one shared
        # offset keeps the two crops aligned with each other.
        offset_h = np.random.randint(0, height - sub_sample_size[0] + 1)
        offset_w = np.random.randint(0, width - sub_sample_size[1] + 1)
        rows = slice(offset_h, offset_h + sub_sample_size[0])
        cols = slice(offset_w, offset_w + sub_sample_size[1])
        return predicted[..., rows, cols], target[..., rows, cols]

    def compute_vgg19_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target)

    def compute_vggface_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vggface, self.vggface_layers, predicted, target)

    def compute_feature_matching_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target, detach=True)

    def compute_perceptual_loss(self, model, layers, predicted, target, detach=False):
        loss = 0.0
        predicted_features = predicted
        target_features = target
        for i, layer in enumerate(model.children()):
            if isinstance(layer, nn.Linear):
                # Flatten spatial dimensions before fully connected layers
                predicted_features = predicted_features.view(predicted_features.size(0), -1)
                target_features = target_features.view(target_features.size(0), -1)
            predicted_features = layer(predicted_features)
            target_features = layer(target_features)
            if i in layers:
                # L1 distance between feature maps at the selected layers
                if detach:
                    loss += torch.mean(torch.abs(predicted_features - target_features.detach()))
                else:
                    loss += torch.mean(torch.abs(predicted_features - target_features))
            if i >= max(layers):
                # Later layers do not contribute to the loss; stop early
                # to save compute and memory.
                break
        return loss

    def normalize_input(self, x):
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        return (x - mean) / std
```
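For reference, a minimal smoke test of the class above (batch shape, image size, and the device pick are assumptions, not from the branch):

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = PerceptualLoss(device)

# Dummy 256x256 RGB batches; in training these would be the generator
# output and the ground-truth frame.
predicted = torch.rand(2, 3, 256, 256, device=device, requires_grad=True)
target = torch.rand(2, 3, 256, 256, device=device)

loss = criterion(predicted, target, sub_sample_size=(128, 128))
loss.backward()
print(loss.item())
```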