We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
drafted https://github.com/johndpope/MegaPortrait-hack/tree/feat/sub-sampling
I don't see a massive speed-up.
It's possible this could be randomized — sometimes using the full, half, or quarter resolution, etc.
class PerceptualLoss(nn.Module):
    """Weighted perceptual loss combining VGG19 and VGGFace feature distances.

    The total loss is a weighted sum of L1 distances between intermediate
    activations of frozen VGG19 / VGGFace networks for ``predicted`` vs.
    ``target`` images. A gaze term exists in the weights but its network
    output is currently disabled (see ``forward``). An optional feature
    matching term (detached targets) can be added on top.
    """

    def __init__(self, device, weights=None):
        """
        Args:
            device: torch device the frozen feature extractors are moved to.
            weights: optional dict with keys 'vgg19', 'vggface', 'gaze'.
                Defaults to {'vgg19': 20.0, 'vggface': 5.0, 'gaze': 4.0}.
        """
        super().__init__()
        self.device = device
        # Copy to avoid aliasing the caller's dict; using None as the default
        # avoids the shared-mutable-default-argument pitfall.
        if weights is None:
            weights = {'vgg19': 20.0, 'vggface': 5.0, 'gaze': 4.0}
        self.weights = dict(weights)

        # VGG19 feature extractor: first 30 layers only, frozen in eval mode.
        vgg19 = models.vgg19(pretrained=True).features
        self.vgg19 = nn.Sequential(*[vgg19[i] for i in range(30)]).to(device).eval()
        # Indices of the layers whose activations contribute to the loss.
        self.vgg19_layers = [1, 6, 11, 20, 29]

        # VGGFace network (InceptionResnetV1 pretrained on vggface2), frozen.
        self.vggface = InceptionResnetV1(pretrained='vggface2').to(device).eval()
        self.vggface_layers = [4, 5, 6, 7]

        # Gaze loss (currently not evaluated in forward — see note there).
        self.gaze_loss = MPGazeLoss(device)

    # Memory trick (MegaPortraits sec. 3.3) — random sub-sampling of inputs.
    # https://arxiv.org/pdf/2404.09736#page=5.58
    def forward(self, predicted, target, sub_sample_size=(128, 128), use_fm_loss=False):
        """Compute the total weighted perceptual loss.

        Args:
            predicted: generated image batch, (B, 3, H, W).
            target: ground-truth image batch, same shape as ``predicted``.
            sub_sample_size: NOTE(review) — accepted but currently UNUSED;
                wire ``sub_sample_tensor`` in here once the speed/quality
                trade-off of random cropping is validated.
            use_fm_loss: if True, add an (unweighted) feature-matching term.

        Returns:
            Scalar tensor: the total loss.
        """
        # Normalize inputs with ImageNet statistics.
        predicted = self.normalize_input(predicted)
        target = self.normalize_input(target)

        vgg19_loss = self.compute_vgg19_loss(predicted, target)
        vggface_loss = self.compute_vggface_loss(predicted, target)

        # Gaze loss is disabled; its weighted constant (weights['gaze'] * 1)
        # is kept so totals stay comparable with previous training runs.
        # gaze_loss = self.gaze_loss(predicted, target)
        total_loss = (
            self.weights['vgg19'] * vgg19_loss
            + self.weights['vggface'] * vggface_loss
            + self.weights['gaze'] * 1  # gaze_loss
        )

        if use_fm_loss:
            # Feature-matching loss (targets detached), added unweighted.
            total_loss += self.compute_feature_matching_loss(predicted, target)

        return total_loss

    def sub_sample_tensor(self, tensor, sub_sample_size):
        """Randomly crop a (B, C, H, W) tensor to ``sub_sample_size``.

        Offsets are drawn uniformly so the whole image is covered over the
        course of training.
        """
        assert tensor.ndim == 4, "Input tensor should have 4 dimensions (batch_size, channels, height, width)"
        assert tensor.shape[-2] >= sub_sample_size[0] and tensor.shape[-1] >= sub_sample_size[1], "Sub-sample size should not exceed the tensor dimensions"
        _, _, height, width = tensor.shape
        # +1 because np.random.randint's high bound is EXCLUSIVE: without it,
        # a crop equal to the full size would call randint(0, 0) and raise
        # ValueError even though the assert above explicitly permits equality.
        offset_h = np.random.randint(0, height - sub_sample_size[0] + 1)
        offset_w = np.random.randint(0, width - sub_sample_size[1] + 1)
        return tensor[..., offset_h:offset_h + sub_sample_size[0],
                      offset_w:offset_w + sub_sample_size[1]]

    def compute_vgg19_loss(self, predicted, target):
        """L1 perceptual loss over the selected VGG19 layers."""
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target)

    def compute_vggface_loss(self, predicted, target):
        """L1 perceptual loss over the selected VGGFace layers."""
        return self.compute_perceptual_loss(self.vggface, self.vggface_layers, predicted, target)

    def compute_feature_matching_loss(self, predicted, target):
        """Feature-matching variant: target activations are detached."""
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target, detach=True)

    def compute_perceptual_loss(self, model, layers, predicted, target, detach=False):
        """Sum of mean-L1 distances between activations at ``layers``.

        Args:
            model: sequential-style module iterated layer by layer.
            layers: indices (into ``model.children()``) that contribute.
            predicted, target: normalized input batches.
            detach: if True, stop gradients through the target features.
        """
        loss = 0.0
        pred_feat = predicted
        tgt_feat = target
        for i, layer in enumerate(model.children()):
            # The original code had identical branches for Conv2d and the
            # generic case; only Linear needs special handling (flatten the
            # spatial dimensions before the fully-connected layer).
            if isinstance(layer, nn.Linear):
                pred_feat = pred_feat.view(pred_feat.size(0), -1)
                tgt_feat = tgt_feat.view(tgt_feat.size(0), -1)
            pred_feat = layer(pred_feat)
            tgt_feat = layer(tgt_feat)
            if i in layers:
                reference = tgt_feat.detach() if detach else tgt_feat
                loss += torch.mean(torch.abs(pred_feat - reference))
        return loss

    def normalize_input(self, x):
        """Normalize an image batch with ImageNet mean/std (per channel)."""
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        return (x - mean) / std
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Uh oh!
There was an error while loading. Please reload this page.
drafted
https://github.com/johndpope/MegaPortrait-hack/tree/feat/sub-sampling
I don't see a massive speed-up.
It's possible this could be randomized — sometimes using the full, half, or quarter resolution, etc.
The text was updated successfully, but these errors were encountered: