Topics: image-inpainting, colab-notebook, high-resolution, colab, generative-adversarial-networks, cnn, generative-adversarial-network, gan, fourier-transform, fourier-convolutions, pytorch, fourier, inpainting-methods, deep-neural-networks, inpainting-algorithm, deep-learning, inpainting, computer-vision
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
1.3 KiB
34 lines
1.3 KiB
3 years ago
|
from typing import List
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
    """Return the mean of the per-element squared error, weighted by region.

    Each element's squared error is scaled by ``weight_missing`` where
    ``mask`` is 1 and by ``weight_known`` where ``mask`` is 0; values of
    ``mask`` in between blend the two weights linearly.

    Args:
        pred: Predicted tensor.
        target: Reference tensor compared against ``pred``.
        mask: Tensor that must broadcast against the element-wise error.
            NOTE(review): presumably a (B, 1, H, W) hole mask -- confirm
            against callers.
        weight_known: Multiplier applied where ``mask`` is 0.
        weight_missing: Multiplier applied where ``mask`` is 1.

    Returns:
        Scalar tensor: the mean over *all* elements of the weighted error
        (the mean is not renormalised by the sum of the weights).
    """
    squared_error = F.mse_loss(pred, target, reduction='none')
    known_term = weight_known * (1 - mask)
    missing_term = weight_missing * mask
    weighted_error = squared_error * (known_term + missing_term)
    return weighted_error.mean()
|
||
|
|
||
|
|
||
|
def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
    """Return the mean of the per-element absolute error, weighted by region.

    Each element's absolute error is scaled by ``weight_missing`` where
    ``mask`` is 1 and by ``weight_known`` where ``mask`` is 0; values of
    ``mask`` in between blend the two weights linearly.

    Args:
        pred: Predicted tensor.
        target: Reference tensor compared against ``pred``.
        mask: Tensor that must broadcast against the element-wise error.
            NOTE(review): presumably a (B, 1, H, W) hole mask -- confirm
            against callers.
        weight_known: Multiplier applied where ``mask`` is 0.
        weight_missing: Multiplier applied where ``mask`` is 1.

    Returns:
        Scalar tensor: the mean over *all* elements of the weighted error
        (the mean is not renormalised by the sum of the weights).
    """
    absolute_error = F.l1_loss(pred, target, reduction='none')
    known_term = weight_known * (1 - mask)
    missing_term = weight_missing * mask
    weighted_error = absolute_error * (known_term + missing_term)
    return weighted_error.mean()
|
||
|
|
||
|
|
||
|
def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
    """Average the squared feature difference over paired feature maps.

    Feature maps are paired positionally with ``zip``, so any surplus
    entries in the longer list are ignored.

    Args:
        fake_features: Feature maps computed for the generated sample.
        target_features: Feature maps computed for the reference sample.
        mask: Optional tensor. When given, it is bilinearly resized to the
            last two dimensions of each feature map and the squared
            difference is multiplied by ``1 - mask``, so locations where the
            mask is 1 contribute nothing. Bilinear ``F.interpolate`` needs a
            4-D input. NOTE(review): presumably (B, 1, H, W) -- confirm.

    Returns:
        Scalar tensor: the mean over feature levels of each level's
        (optionally weighted) mean squared difference. In the masked case
        each level's mean runs over all elements, not only unmasked ones.
    """
    feature_pairs = zip(fake_features, target_features)

    # Unmasked case: plain MSE per level, then the mean across levels.
    if mask is None:
        level_losses = [F.mse_loss(generated, reference)
                        for generated, reference in feature_pairs]
        return torch.stack(level_losses).mean()

    # Masked case: down-weight each level by the mask resized to its size.
    level_losses = []
    for generated, reference in feature_pairs:
        resized_mask = F.interpolate(mask, size=generated.shape[-2:], mode='bilinear', align_corners=False)
        squared_diff = (generated - reference) ** 2
        level_losses.append((squared_diff * (1 - resized_mask)).mean())

    # Sequential sum starting from 0, then divide by the number of levels.
    return sum(level_losses) / len(level_losses)
|