Topics: image-inpainting, colab-notebook, high-resolution, colab, generative-adversarial-networks, cnn, generative-adversarial-network, gan, fourier-transform, fourier-convolutions, pytorch, fourier, inpainting-methods, deep-neural-networks, inpainting-algorithm, deep-learning, inpainting, computer-vision
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
5.2 KiB
125 lines
5.2 KiB
3 years ago
|
#!/usr/bin/env python3
|
||
|
|
||
|
import glob
|
||
|
import os
|
||
|
import shutil
|
||
|
import traceback
|
||
|
import hydra
|
||
|
from omegaconf import OmegaConf
|
||
|
|
||
|
import PIL.Image as Image
|
||
|
import numpy as np
|
||
|
from joblib import Parallel, delayed
|
||
|
|
||
|
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
|
||
|
from saicinpainting.evaluation.utils import load_yaml, SmallMode
|
||
|
from saicinpainting.training.data.masks import MixedMaskGenerator
|
||
|
|
||
|
|
||
|
class MakeManyMasksWrapper:
    """Expose a one-mask-per-call generator through a multi-mask ``get_masks`` API.

    ``impl`` is called ``variants_n`` times on the same channel-first copy of the
    input image, and element ``[0]`` of each result is collected into a list.
    """

    def __init__(self, impl, variants_n=2):
        # Callable invoked once per requested mask variant.
        self.impl = impl
        # How many masks each ``get_masks`` call returns.
        self.variants_n = variants_n

    def get_masks(self, img):
        """Return a list of ``variants_n`` masks produced by ``impl`` for ``img``.

        ``img`` is converted with ``np.array`` and its axes are permuted with
        ``(2, 0, 1)``. This assumes a channel-last, 3-D input such as an RGB PIL
        image, which becomes channel-first. Only ``[0]`` of each ``impl`` result
        is kept. Presumably this drops a leading singleton channel axis; confirm
        against the generator in use.
        """
        channels_first = np.array(img).transpose(2, 0, 1)
        masks = []
        for _ in range(self.variants_n):
            masks.append(self.impl(channels_first)[0])
        return masks
|
||
|
|
||
|
|
||
|
def _build_mask_generator(config):
    """Create the mask generator selected by ``config.generator_kind``.

    ``'segmentation'`` builds a ``SegmentationMask`` from ``config.mask_generator_kwargs``.
    ``'random'`` wraps a ``MixedMaskGenerator`` in ``MakeManyMasksWrapper``. In that case the
    optional ``variants_n`` kwarg (default 2) is consumed by the wrapper, not the generator.

    Raises:
        ValueError: for any other ``generator_kind``.
    """
    if config.generator_kind == 'segmentation':
        return SegmentationMask(**config.mask_generator_kwargs)
    if config.generator_kind == 'random':
        mask_generator_kwargs = OmegaConf.to_container(config.mask_generator_kwargs, resolve=True)
        variants_n = mask_generator_kwargs.pop('variants_n', 2)
        return MakeManyMasksWrapper(MixedMaskGenerator(**mask_generator_kwargs),
                                    variants_n=variants_n)
    raise ValueError(f'Unexpected generator kind: {config.generator_kind}')


def _resize_to_min_side(image, min_side):
    """Bicubically resize ``image`` so its shorter side equals ``min_side``, keeping the aspect ratio."""
    factor = min_side / min(image.size)
    # Plain int tuple: the documented size type for PIL's resize.
    out_size = tuple(round(side * factor) for side in image.size)
    return image.resize(out_size, resample=Image.BICUBIC)


def _rescale_image(image, cropping):
    """Scale ``image`` to the output resolution, or return ``None`` if it must be dropped.

    If the shorter side is at least ``cropping.out_min_size``, the image is resized so that
    side equals ``out_min_size``. Smaller images follow ``cropping.handle_small_mode``:
    - ``DROP`` returns ``None``.
    - ``UPSCALE`` resizes the same way.
    - Any other mode returns the image unchanged.
    """
    min_size = cropping.out_min_size
    if min(image.size) < min_size:
        handle_small_mode = SmallMode(cropping.handle_small_mode)
        if handle_small_mode == SmallMode.DROP:
            return None
        if handle_small_mode != SmallMode.UPSCALE:
            return image
    return _resize_to_min_side(image, min_size)


def _is_usable_mask(mask, max_tamper_area):
    """Return True if ``mask`` has a nonzero pixel and its mean does not exceed ``max_tamper_area``."""
    # Reject zero-size and all-zero masks: they leave nothing to inpaint.
    # ``np.unique`` of a non-empty array is never empty, so counting unique values cannot detect this.
    if not np.any(mask):
        return False
    return not mask.mean() > max_tamper_area


def _select_candidate_pairs(image, masks, cropping, max_tamper_area):
    """Pair each mask with its image and keep the usable (image, mask) pairs.

    When ``cropping.out_square_crop`` is set, both the mask and the image are square-cropped first.
    """
    pairs = []
    for mask in masks:
        if cropping.out_square_crop:
            left, top, right, bottom = propose_random_square_crop(
                mask, min_overlap=cropping.crop_min_overlap)
            mask = mask[top:bottom, left:right]
            # Image.crop returns a new image, so no defensive copy is needed.
            cur_image = image.crop((left, top, right, bottom))
        else:
            cur_image = image
        if _is_usable_mask(mask, max_tamper_area):
            pairs.append((cur_image, mask))
    return pairs


def process_images(src_images, indir, outdir, config):
    """Generate, filter and save masks (with matching image crops) for ``src_images``.

    Args:
        src_images: image paths. Each is expected to start with ``indir``, which is
            stripped to obtain the output location relative to ``outdir``.
        indir: input root, used as a plain string prefix (``main`` appends a trailing '/').
        outdir: output root. Sub-directories mirroring the input tree are created.
        config: generation config. Reads:
            - ``generator_kind`` and ``mask_generator_kwargs``
            - ``max_tamper_area`` (optional, default 1)
            - ``max_masks_per_image``
            - ``cropping.{out_min_size, handle_small_mode, out_square_crop, crop_min_overlap}``

    For every source image, up to ``config.max_masks_per_image`` usable (image, mask) pairs
    are picked at random without replacement. Each pair is written as:
    - ``<stem>_crop{i:03d}.png``: the image.
    - ``<stem>_crop{i:03d}_mask{i:03d}.png``: the mask, scaled by 255, clipped and saved
      as 8-bit 'L'. This assumes mask values lie in [0, 1]; TODO confirm per generator.

    Per-file failures are printed with a traceback and skipped. A ``KeyboardInterrupt``
    ends this call quietly.

    Raises:
        ValueError: if ``config.generator_kind`` is not ``'segmentation'`` or ``'random'``.
    """
    mask_generator = _build_mask_generator(config)
    max_tamper_area = config.get('max_tamper_area', 1)

    for infile in src_images:
        try:
            file_relpath = infile[len(indir):]
            img_outpath = os.path.join(outdir, file_relpath)
            os.makedirs(os.path.dirname(img_outpath), exist_ok=True)

            # ``with`` releases the file handle even for multi-frame formats.
            with Image.open(infile) as src:
                image = src.convert('RGB')

            # Scale the image to the output resolution and filter out images that are too small.
            image = _rescale_image(image, config.cropping)
            if image is None:
                continue

            pairs = _select_candidate_pairs(image, mask_generator.get_masks(image),
                                            config.cropping, max_tamper_area)
            mask_indices = np.random.choice(len(pairs),
                                            size=min(len(pairs), config.max_masks_per_image),
                                            replace=False)

            # Save each selected mask together with its (possibly cropped) input image.
            mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
            for i, idx in enumerate(mask_indices):
                cur_image, cur_mask = pairs[idx]
                cur_basename = mask_basename + f'_crop{i:03d}'
                Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
                                mode='L').save(cur_basename + f'_mask{i:03d}.png')
                cur_image.save(cur_basename + '.png')
        except KeyboardInterrupt:
            return
        except Exception as ex:
            print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
|
||
|
|
||
|
|
||
|
@hydra.main(config_path='../configs/data_gen/whydra', config_name='random_medium_256.yaml')
def main(config: OmegaConf):
    """Generate masks for every ``*.<config.location.extension>`` file under ``config.indir``.

    Config fields read here are ``indir``, ``outdir``, ``location.extension`` and ``n_jobs``.
    The whole config is also forwarded to ``process_images``.

    With ``n_jobs == 0`` all files are processed in the current process. Any other value
    follows joblib's convention, where negative values count back from the number of CPUs.
    In that case the file list is split into contiguous chunks, one per effective worker,
    and each chunk is handled by its own ``process_images`` call.
    """
    # Local import: only this entry point needs it, and joblib is already a dependency.
    from joblib import effective_n_jobs

    # process_images strips ``indir`` as a plain string prefix, so it must end with a separator.
    if not config.indir.endswith('/'):
        config.indir += '/'

    os.makedirs(config.outdir, exist_ok=True)

    in_files = glob.glob(os.path.join(config.indir, '**', f'*.{config.location.extension}'),
                         recursive=True)
    if not in_files:
        # Nothing to do. Returning here also avoids a zero chunk size (a ``range`` step of 0) below.
        print(f'No *.{config.location.extension} files found under {config.indir}')
        return

    if config.n_jobs == 0:
        process_images(in_files, config.indir, config.outdir, config)
        return

    # Resolve values such as n_jobs=-1 ("all CPUs") to a positive worker count.
    # Using the raw negative value gave a negative chunk size and silently skipped every file.
    n_workers = effective_n_jobs(config.n_jobs)
    chunk_size = -(-len(in_files) // n_workers)  # ceiling division
    Parallel(n_jobs=config.n_jobs)(
        delayed(process_images)(in_files[start:start + chunk_size], config.indir, config.outdir, config)
        for start in range(0, len(in_files), chunk_size)
    )
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Script entry point: the hydra.main decorator on ``main`` supplies ``config``.
    main()
|