add fallback model URLS

pull/67/head
MSFTserver 2 years ago
parent af10d0d737
commit 1fceed19b7
  1. 45
      Disco_Diffusion.ipynb
  2. 45
      disco.py

@ -478,10 +478,6 @@
" root_path = os.getcwd()\n",
" model_path = f'{root_path}/models'\n",
"\n",
"model_256_downloaded = False\n",
"model_512_downloaded = False\n",
"model_secondary_downloaded = False\n",
"\n",
"multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"print(multipip_res)\n",
"\n",
@ -1744,6 +1740,11 @@
"#@markdown If you're having issues with model downloads, check this to compare SHA's:\n",
"check_model_SHA = False #@param{type:\"boolean\"}\n",
"\n",
"def download_models(diffusion_model,use_secondary_model,fallback=False):\n",
" model_256_downloaded = False\n",
" model_512_downloaded = False\n",
" model_secondary_downloaded = False\n",
"\n",
" model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
" model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n",
" model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
@ -1752,10 +1753,18 @@
" model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n",
" model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n",
"\n",
" model_256_link_fb = 'https://www.dropbox.com/s/9tqnqo930mpnpcn/256x256_diffusion_uncond.pt'\n",
" model_512_link_fb = 'https://www.dropbox.com/s/yjqvhu6l6l0r2eh/512x512_diffusion_uncond_finetune_008100.pt'\n",
" model_secondary_link_fb = 'https://www.dropbox.com/s/luv4fezod3r8d2n/secondary_model_imagenet_2.pth'\n",
"\n",
" model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n",
" model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n",
" model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n",
"\n",
" if fallback:\n",
" model_256_link = model_256_link_fb\n",
" model_512_link = model_512_link_fb\n",
" model_secondary_link = model_secondary_link_fb\n",
" # Download the diffusion model\n",
" if diffusion_model == '256x256_diffusion_uncond':\n",
" if os.path.exists(model_256_path) and check_model_SHA:\n",
@ -1769,12 +1778,20 @@
" else: \n",
" print(\"256 Model SHA doesn't match, redownloading...\")\n",
" wget(model_256_link, model_path)\n",
" if os.path.exists(model_256_path):\n",
" model_256_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n",
" print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
" else: \n",
" wget(model_256_link, model_path)\n",
" if os.path.exists(model_256_path):\n",
" model_256_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,True)\n",
" elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
" if os.path.exists(model_512_path) and check_model_SHA:\n",
" print('Checking 512 Diffusion File')\n",
@ -1783,18 +1800,24 @@
" hash = hashlib.sha256(bytes).hexdigest();\n",
" if hash == model_512_SHA:\n",
" print('512 Model SHA matches')\n",
" if os.path.exists(model_512_path):\n",
" model_512_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" else: \n",
" print(\"512 Model SHA doesn't match, redownloading...\")\n",
" wget(model_512_link, model_path)\n",
" if os.path.exists(model_512_path):\n",
" model_512_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n",
" print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
" else: \n",
" wget(model_512_link, model_path)\n",
" model_512_downloaded = True\n",
"\n",
"\n",
" # Download the secondary diffusion model v2\n",
" if use_secondary_model == True:\n",
" if os.path.exists(model_secondary_path) and check_model_SHA:\n",
@ -1808,12 +1831,22 @@
" else: \n",
" print(\"Secondary Model SHA doesn't match, redownloading...\")\n",
" wget(model_secondary_link, model_path)\n",
" if os.path.exists(model_secondary_path):\n",
" model_secondary_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n",
" print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n",
" else: \n",
" wget(model_secondary_link, model_path)\n",
" if os.path.exists(model_secondary_path):\n",
" model_secondary_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
"\n",
"download_models(diffusion_model,use_secondary_model)\n",
"\n",
"model_config = model_and_diffusion_defaults()\n",
"if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",

@ -451,10 +451,6 @@ else:
root_path = os.getcwd()
model_path = f'{root_path}/models'
model_256_downloaded = False
model_512_downloaded = False
model_secondary_downloaded = False
multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')
print(multipip_res)
@ -1697,6 +1693,11 @@ RN50x64 = False #@param{type:"boolean"}
#@markdown If you're having issues with model downloads, check this to compare SHA's:
check_model_SHA = False #@param{type:"boolean"}
def download_models(diffusion_model,use_secondary_model,fallback=False):
model_256_downloaded = False
model_512_downloaded = False
model_secondary_downloaded = False
model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'
model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
@ -1705,10 +1706,18 @@ model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/
model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'
model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'
model_256_link_fb = 'https://www.dropbox.com/s/9tqnqo930mpnpcn/256x256_diffusion_uncond.pt'
model_512_link_fb = 'https://www.dropbox.com/s/yjqvhu6l6l0r2eh/512x512_diffusion_uncond_finetune_008100.pt'
model_secondary_link_fb = 'https://www.dropbox.com/s/luv4fezod3r8d2n/secondary_model_imagenet_2.pth'
model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'
model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'
model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'
if fallback:
model_256_link = model_256_link_fb
model_512_link = model_512_link_fb
model_secondary_link = model_secondary_link_fb
# Download the diffusion model
if diffusion_model == '256x256_diffusion_uncond':
if os.path.exists(model_256_path) and check_model_SHA:
@ -1722,12 +1731,20 @@ if diffusion_model == '256x256_diffusion_uncond':
else:
print("256 Model SHA doesn't match, redownloading...")
wget(model_256_link, model_path)
if os.path.exists(model_256_path):
model_256_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:
print('256 Model already downloaded, check check_model_SHA if the file is corrupt')
else:
wget(model_256_link, model_path)
if os.path.exists(model_256_path):
model_256_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,True)
elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':
if os.path.exists(model_512_path) and check_model_SHA:
print('Checking 512 Diffusion File')
@ -1736,18 +1753,24 @@ elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':
hash = hashlib.sha256(bytes).hexdigest();
if hash == model_512_SHA:
print('512 Model SHA matches')
if os.path.exists(model_512_path):
model_512_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
else:
print("512 Model SHA doesn't match, redownloading...")
wget(model_512_link, model_path)
if os.path.exists(model_512_path):
model_512_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:
print('512 Model already downloaded, check check_model_SHA if the file is corrupt')
else:
wget(model_512_link, model_path)
model_512_downloaded = True
# Download the secondary diffusion model v2
if use_secondary_model == True:
if os.path.exists(model_secondary_path) and check_model_SHA:
@ -1761,12 +1784,22 @@ if use_secondary_model == True:
else:
print("Secondary Model SHA doesn't match, redownloading...")
wget(model_secondary_link, model_path)
if os.path.exists(model_secondary_path):
model_secondary_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:
print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')
else:
wget(model_secondary_link, model_path)
if os.path.exists(model_secondary_path):
model_secondary_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
download_models(diffusion_model,use_secondary_model)
model_config = model_and_diffusion_defaults()
if diffusion_model == '512x512_diffusion_uncond_finetune_008100':

Loading…
Cancel
Save