"""Perceptual losses: a frozen-VGG19 PerceptualLoss and a ResNetPL built on
an ADE20k-pretrained ResNet encoder."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from models.ade20k import ModelBuilder
from saicinpainting.utils import check_and_warn_input_range


# ImageNet statistics, shaped (1, C, 1, 1) so they broadcast over NCHW batches
IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]


class PerceptualLoss(nn.Module):
    def __init__(self, normalize_inputs=True):
        super(PerceptualLoss, self).__init__()

        self.normalize_inputs = normalize_inputs
        self.mean_ = IMAGENET_MEAN
        self.std_ = IMAGENET_STD

        # note: torchvision >= 0.13 prefers `weights=...` over `pretrained=True`
        vgg = torchvision.models.vgg19(pretrained=True).features
        vgg_avg_pooling = []

        # the network is a fixed feature extractor; freeze all weights
        for weights in vgg.parameters():
            weights.requires_grad = False

        # rebuild VGG19 with every MaxPool2d swapped for an AvgPool2d of the
        # same kernel size and stride
        for module in vgg.modules():
            if module.__class__.__name__ == 'Sequential':
                continue
            elif module.__class__.__name__ == 'MaxPool2d':
                vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
            else:
                vgg_avg_pooling.append(module)

        self.vgg = nn.Sequential(*vgg_avg_pooling)

    def do_normalize_inputs(self, x):
        # standard ImageNet normalization, applied channel-wise
        return (x - self.mean_.to(x.device)) / self.std_.to(x.device)

    def partial_losses(self, input, target, mask=None):
        check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')

        # input and target are expected to be in the [0, 1] range
        losses = []

        if self.normalize_inputs:
            features_input = self.do_normalize_inputs(input)
            features_target = self.do_normalize_inputs(target)
        else:
            features_input = input
            features_target = target

        # run both images through the first 30 VGG modules, collecting an
        # MSE term after every ReLU activation
        for layer in self.vgg[:30]:
            features_input = layer(features_input)
            features_target = layer(features_target)

            if layer.__class__.__name__ == 'ReLU':
                loss = F.mse_loss(features_input, features_target, reduction='none')

                if mask is not None:
                    # resize the mask to the current feature resolution and
                    # zero the loss inside the masked (mask == 1) region
                    cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
                                             mode='bilinear', align_corners=False)
                    loss = loss * (1 - cur_mask)

                # reduce over channel and spatial dims, keeping the batch dim
                loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
                losses.append(loss)

        return losses

    def forward(self, input, target, mask=None):
        losses = self.partial_losses(input, target, mask=mask)
        return torch.stack(losses).sum(dim=0)

    def get_global_features(self, input):
        check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')

        if self.normalize_inputs:
            features_input = self.do_normalize_inputs(input)
        else:
            features_input = input

        features_input = self.vgg(features_input)
        return features_input
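
# Illustrative usage (a sketch, not part of the original module; the tensor
# shapes below are assumptions). PerceptualLoss expects NCHW images in the
# [0, 1] range; `forward` sums the per-ReLU-layer terms and returns one loss
# value per sample:
#
#     loss_fn = PerceptualLoss()
#     pred = torch.rand(2, 3, 256, 256)
#     target = torch.rand(2, 3, 256, 256)
#     mask = torch.zeros(2, 1, 256, 256)  # 1 = pixels excluded from the loss
#     per_sample = loss_fn(pred, target, mask=mask)  # shape: (2,)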


class ResNetPL(nn.Module):
    def __init__(self, weight=1,
                 weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
        super().__init__()
        self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
                                             arch_encoder=arch_encoder,
                                             arch_decoder='ppm_deepsup',
                                             fc_dim=2048,
                                             segmentation=segmentation)
        # the encoder is a frozen feature extractor, never trained here
        self.impl.eval()
        for w in self.impl.parameters():
            w.requires_grad_(False)

        self.weight = weight

    def forward(self, pred, target):
        # normalize both images with ImageNet statistics before encoding
        pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
        target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)

        pred_feats = self.impl(pred, return_feature_maps=True)
        target_feats = self.impl(target, return_feature_maps=True)

        # sum the MSE between corresponding feature maps, then scale
        result = torch.stack([F.mse_loss(cur_pred, cur_target)
                              for cur_pred, cur_target
                              in zip(pred_feats, target_feats)]).sum() * self.weight
        return result
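
# Illustrative usage (a sketch, not part of the original module). ResNetPL
# compares multi-scale feature maps from an ADE20k-pretrained dilated
# ResNet50 encoder, so it needs those weights available; the path below is
# a placeholder assumption:
#
#     resnet_pl = ResNetPL(weights_path='/path/to/ade20k/weights')
#     pred = torch.rand(2, 3, 256, 256)
#     target = torch.rand(2, 3, 256, 256)
#     loss = resnet_pl(pred, target)  # scalar tensor, scaled by `weight`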