Check if using secondary model before loading it

pull/11/head
Nate Baer 3 years ago
parent 92d90c0666
commit 7c5e365f74
  1. 5
      Disco_Diffusion.ipynb
  2. 5
      disco.py

@ -2321,15 +2321,14 @@
" 'use_scale_shift_norm': True,\n", " 'use_scale_shift_norm': True,\n",
" })\n", " })\n",
"\n", "\n",
"secondary_model_ver = 2\n",
"model_default = model_config['image_size']\n", "model_default = model_config['image_size']\n",
"\n", "\n",
"\n", "\n",
"\n", "\n",
"if secondary_model_ver == 2:\n", "if use_secondary_model:\n",
" secondary_model = SecondaryDiffusionImageNet2()\n", " secondary_model = SecondaryDiffusionImageNet2()\n",
" secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))\n", " secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))\n",
"secondary_model.eval().requires_grad_(False).to(device)\n", " secondary_model.eval().requires_grad_(False).to(device)\n",
"\n", "\n",
"clip_models = []\n", "clip_models = []\n",
"if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) \n", "if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) \n",

@ -2199,15 +2199,14 @@ elif diffusion_model == '256x256_diffusion_uncond':
'use_scale_shift_norm': True, 'use_scale_shift_norm': True,
}) })
secondary_model_ver = 2
model_default = model_config['image_size'] model_default = model_config['image_size']
if secondary_model_ver == 2: if use_secondary_model:
secondary_model = SecondaryDiffusionImageNet2() secondary_model = SecondaryDiffusionImageNet2()
secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu')) secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))
secondary_model.eval().requires_grad_(False).to(device) secondary_model.eval().requires_grad_(False).to(device)
clip_models = [] clip_models = []
if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device))

Loading…
Cancel
Save