remove slip

pull/26/head
MSFTserver committed 3 years ago
parent a1f25c79ec
commit c509aa1b9c
  1 changed file with 2 additions and 19 deletions:
      disco.py

@@ -246,6 +246,8 @@ if skip_for_run_all == False:
+  Remove Super Resolution
+  Remove SLIP Models
   '''
 )
@@ -439,7 +441,6 @@ model_secondary_downloaded = False
 if is_colab:
   gitclone("https://github.com/openai/CLIP")
-  #gitclone("https://github.com/facebookresearch/SLIP.git")
   gitclone("https://github.com/crowsonkb/guided-diffusion")
   gitclone("https://github.com/assafshocher/ResizeRight.git")
   gitclone("https://github.com/MSFTserver/pytorch3d-lite.git")
@@ -468,7 +469,6 @@ if not os.path.exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):
 import sys
 import torch
-# sys.path.append('./SLIP')
 sys.path.append('./pytorch3d-lite')
 sys.path.append('./ResizeRight')
 sys.path.append('./MiDaS')
@@ -496,7 +496,6 @@ sys.path.append('./CLIP')
 sys.path.append('./guided-diffusion')
 import clip
 from resize_right import resize
-# from models import SLIP_VITB16, SLIP, SLIP_VITL16
 from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
 from datetime import datetime
 import numpy as np
@@ -1636,8 +1635,6 @@ RN50 = True #@param{type:"boolean"}
 RN50x4 = False #@param{type:"boolean"}
 RN50x16 = False #@param{type:"boolean"}
 RN50x64 = False #@param{type:"boolean"}
-SLIPB16 = False #@param{type:"boolean"}
-SLIPL16 = False #@param{type:"boolean"}
 #@markdown If you're having issues with model downloads, check this to compare SHA's:
 check_model_SHA = False #@param{type:"boolean"}
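
Note: check_model_SHA refers to verifying downloaded checkpoints against known-good hashes before loading them. A minimal sketch of that kind of check, assuming SHA-256 digests; the expected values and exact file names used by disco.py live elsewhere in the notebook:

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file in chunks so large model checkpoints need not fit in memory.
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# Hypothetical usage: compare against a published digest before trusting the file.
# expected = '...'  # known-good SHA-256 for the checkpoint
# if sha256_of(f'{model_path}/some_model.pt') != expected:
#     print('Checksum mismatch, re-download the model')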
@@ -1771,20 +1768,6 @@ if RN50x16 is True: clip_models.append(clip.load('RN50x16', jit=False)[0].eval()
 if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device))
 if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device))
-if SLIPB16:
-  SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
-  if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):
-    wget("https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt", model_path)
-  sd = torch.load(f'{model_path}/slip_base_100ep.pt')
-  real_sd = {}
-  for k, v in sd['state_dict'].items():
-    real_sd['.'.join(k.split('.')[1:])] = v
-  del sd
-  SLIPB16model.load_state_dict(real_sd)
-  SLIPB16model.requires_grad_(False).eval().to(device)
-  clip_models.append(SLIPL16model)
 normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
 lpips_model = lpips.LPIPS(net='vgg').to(device)
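
Note: the deleted block remapped checkpoint keys before load_state_dict because the published SLIP weights carry a wrapper prefix on every key (commonly 'module.' from DataParallel-style training). A standalone sketch of that remapping pattern, using a hypothetical checkpoint dict rather than the real SLIP weights:

import torch

def strip_wrapper_prefix(state_dict):
    # Drop the first dotted component of each key, e.g. 'module.visual.proj' -> 'visual.proj',
    # which is what the removed loop building real_sd did.
    return {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items()}

# Hypothetical checkpoint with wrapped keys.
ckpt = {'state_dict': {'module.visual.proj': torch.zeros(2, 2)}}
clean = strip_wrapper_prefix(ckpt['state_dict'])
assert 'visual.proj' in clean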
