From c390bd329c33b0d95458808ec0db0e827555f711 Mon Sep 17 00:00:00 2001 From: Sanster Date: Wed, 29 Dec 2021 11:01:44 +0800 Subject: [PATCH] add torchscript convert script --- bin/to_jit.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 bin/to_jit.py diff --git a/bin/to_jit.py b/bin/to_jit.py new file mode 100644 index 0000000..8acea39 --- /dev/null +++ b/bin/to_jit.py @@ -0,0 +1,75 @@ +import os +from pathlib import Path + +import hydra +import torch +import yaml +from omegaconf import OmegaConf +from torch import nn + +from saicinpainting.training.trainers import load_checkpoint +from saicinpainting.utils import register_debug_signal_handlers + + +class JITWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, image, mask): + batch = { + "image": image, + "mask": mask + } + out = self.model(batch) + return out["inpainted"] + + +@hydra.main(config_path="../configs/prediction", config_name="default.yaml") +def main(predict_config: OmegaConf): + register_debug_signal_handlers() # kill -10 will result in traceback dumped into log + + train_config_path = os.path.join(predict_config.model.path, "config.yaml") + with open(train_config_path, "r") as f: + train_config = OmegaConf.create(yaml.safe_load(f)) + + train_config.training_model.predict_only = True + train_config.visualizer.kind = "noop" + + checkpoint_path = os.path.join( + predict_config.model.path, "models", predict_config.model.checkpoint + ) + model = load_checkpoint( + train_config, checkpoint_path, strict=False, map_location="cpu" + ) + model.eval() + jit_model_wrapper = JITWrapper(model) + + image = torch.rand(1, 3, 120, 120) + mask = torch.rand(1, 1, 120, 120) + output = jit_model_wrapper(image, mask) + + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + image = image.to(device) + mask = mask.to(device) + traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False).to(device) + + save_path = Path(predict_config.save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Saving big-lama.pt model to {save_path}") + traced_model.save(save_path) + + print(f"Checking jit model output...") + jit_model = torch.jit.load(str(save_path)) + jit_output = jit_model(image, mask) + diff = (output - jit_output).abs().sum() + print(f"diff: {diff}") + + +if __name__ == "__main__": + main()