image-inpaintingcolab-notebookhigh-resolutioncolabgenerative-adversarial-networkscnngenerative-adversarial-networkganfourier-transformfourier-convolutionspytorchfourierinpainting-methodsdeep-neural-networksinpainting-algorithmdeep-learninginpaintingcomputer-vision
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.1 KiB
76 lines
2.1 KiB
3 years ago
|
import os
|
||
|
from pathlib import Path
|
||
|
|
||
|
import hydra
|
||
|
import torch
|
||
|
import yaml
|
||
|
from omegaconf import OmegaConf
|
||
|
from torch import nn
|
||
|
|
||
|
from saicinpainting.training.trainers import load_checkpoint
|
||
|
from saicinpainting.utils import register_debug_signal_handlers
|
||
|
|
||
|
|
||
|
class JITWrapper(nn.Module):
|
||
|
def __init__(self, model):
|
||
|
super().__init__()
|
||
|
self.model = model
|
||
|
|
||
|
def forward(self, image, mask):
|
||
|
batch = {
|
||
|
"image": image,
|
||
|
"mask": mask
|
||
|
}
|
||
|
out = self.model(batch)
|
||
|
return out["inpainted"]
|
||
|
|
||
|
|
||
|
@hydra.main(config_path="../configs/prediction", config_name="default.yaml")
|
||
|
def main(predict_config: OmegaConf):
|
||
|
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
||
|
|
||
|
train_config_path = os.path.join(predict_config.model.path, "config.yaml")
|
||
|
with open(train_config_path, "r") as f:
|
||
|
train_config = OmegaConf.create(yaml.safe_load(f))
|
||
|
|
||
|
train_config.training_model.predict_only = True
|
||
|
train_config.visualizer.kind = "noop"
|
||
|
|
||
|
checkpoint_path = os.path.join(
|
||
|
predict_config.model.path, "models", predict_config.model.checkpoint
|
||
|
)
|
||
|
model = load_checkpoint(
|
||
|
train_config, checkpoint_path, strict=False, map_location="cpu"
|
||
|
)
|
||
|
model.eval()
|
||
|
jit_model_wrapper = JITWrapper(model)
|
||
|
|
||
|
image = torch.rand(1, 3, 120, 120)
|
||
|
mask = torch.rand(1, 1, 120, 120)
|
||
|
output = jit_model_wrapper(image, mask)
|
||
|
|
||
|
if torch.cuda.is_available():
|
||
|
device = torch.device("cuda")
|
||
|
else:
|
||
|
device = torch.device("cpu")
|
||
|
|
||
|
image = image.to(device)
|
||
|
mask = mask.to(device)
|
||
|
traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False).to(device)
|
||
|
|
||
|
save_path = Path(predict_config.save_path)
|
||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
|
||
|
print(f"Saving big-lama.pt model to {save_path}")
|
||
|
traced_model.save(save_path)
|
||
|
|
||
|
print(f"Checking jit model output...")
|
||
|
jit_model = torch.jit.load(str(save_path))
|
||
|
jit_output = jit_model(image, mask)
|
||
|
diff = (output - jit_output).abs().sum()
|
||
|
print(f"diff: {diff}")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
main()
|