add fallback model URLS

pull/67/head
MSFTserver 2 years ago
parent af10d0d737
commit 1fceed19b7
  1. 169
      Disco_Diffusion.ipynb
  2. 169
      disco.py

@ -478,10 +478,6 @@
" root_path = os.getcwd()\n", " root_path = os.getcwd()\n",
" model_path = f'{root_path}/models'\n", " model_path = f'{root_path}/models'\n",
"\n", "\n",
"model_256_downloaded = False\n",
"model_512_downloaded = False\n",
"model_secondary_downloaded = False\n",
"\n",
"multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", "multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"print(multipip_res)\n", "print(multipip_res)\n",
"\n", "\n",
@ -1744,76 +1740,113 @@
"#@markdown If you're having issues with model downloads, check this to compare SHA's:\n", "#@markdown If you're having issues with model downloads, check this to compare SHA's:\n",
"check_model_SHA = False #@param{type:\"boolean\"}\n", "check_model_SHA = False #@param{type:\"boolean\"}\n",
"\n", "\n",
"model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", "def download_models(diffusion_model,use_secondary_model,fallback=False):\n",
"model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n", " model_256_downloaded = False\n",
"model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", " model_512_downloaded = False\n",
"\n", " model_secondary_downloaded = False\n",
"model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n", "\n",
"model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n", " model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
"model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n", " model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n",
"\n", " model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
"model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n", "\n",
"model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n", " model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n",
"model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n", " model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n",
"\n", " model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n",
"# Download the diffusion model\n", "\n",
"if diffusion_model == '256x256_diffusion_uncond':\n", " model_256_link_fb = 'https://www.dropbox.com/s/9tqnqo930mpnpcn/256x256_diffusion_uncond.pt'\n",
" if os.path.exists(model_256_path) and check_model_SHA:\n", " model_512_link_fb = 'https://www.dropbox.com/s/yjqvhu6l6l0r2eh/512x512_diffusion_uncond_finetune_008100.pt'\n",
" print('Checking 256 Diffusion File')\n", " model_secondary_link_fb = 'https://www.dropbox.com/s/luv4fezod3r8d2n/secondary_model_imagenet_2.pth'\n",
" with open(model_256_path,\"rb\") as f:\n", "\n",
" bytes = f.read() \n", " model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n",
" hash = hashlib.sha256(bytes).hexdigest();\n", " model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n",
" if hash == model_256_SHA:\n", " model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n",
" print('256 Model SHA matches')\n", "\n",
" model_256_downloaded = True\n", " if fallback:\n",
" else: \n", " model_256_link = model_256_link_fb\n",
" print(\"256 Model SHA doesn't match, redownloading...\")\n", " model_512_link = model_512_link_fb\n",
" model_secondary_link = model_secondary_link_fb\n",
" # Download the diffusion model\n",
" if diffusion_model == '256x256_diffusion_uncond':\n",
" if os.path.exists(model_256_path) and check_model_SHA:\n",
" print('Checking 256 Diffusion File')\n",
" with open(model_256_path,\"rb\") as f:\n",
" bytes = f.read() \n",
" hash = hashlib.sha256(bytes).hexdigest();\n",
" if hash == model_256_SHA:\n",
" print('256 Model SHA matches')\n",
" model_256_downloaded = True\n",
" else: \n",
" print(\"256 Model SHA doesn't match, redownloading...\")\n",
" wget(model_256_link, model_path)\n",
" if os.path.exists(model_256_path):\n",
" model_256_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n",
" print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
" else: \n",
" wget(model_256_link, model_path)\n", " wget(model_256_link, model_path)\n",
" model_256_downloaded = True\n", " if os.path.exists(model_256_path):\n",
" elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n", " model_256_downloaded = True\n",
" print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n", " else:\n",
" else: \n", " print('First URL Failed using FallBack')\n",
" wget(model_256_link, model_path)\n", " download_models(diffusion_model,True)\n",
" model_256_downloaded = True\n", " elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
"elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", " if os.path.exists(model_512_path) and check_model_SHA:\n",
" if os.path.exists(model_512_path) and check_model_SHA:\n", " print('Checking 512 Diffusion File')\n",
" print('Checking 512 Diffusion File')\n", " with open(model_512_path,\"rb\") as f:\n",
" with open(model_512_path,\"rb\") as f:\n", " bytes = f.read() \n",
" bytes = f.read() \n", " hash = hashlib.sha256(bytes).hexdigest();\n",
" hash = hashlib.sha256(bytes).hexdigest();\n", " if hash == model_512_SHA:\n",
" if hash == model_512_SHA:\n", " print('512 Model SHA matches')\n",
" print('512 Model SHA matches')\n", " if os.path.exists(model_512_path):\n",
" model_512_downloaded = True\n", " model_512_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" else: \n",
" print(\"512 Model SHA doesn't match, redownloading...\")\n",
" wget(model_512_link, model_path)\n",
" if os.path.exists(model_512_path):\n",
" model_512_downloaded = True\n",
" else:\n",
" print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n",
" print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
" else: \n", " else: \n",
" print(\"512 Model SHA doesn't match, redownloading...\")\n",
" wget(model_512_link, model_path)\n", " wget(model_512_link, model_path)\n",
" model_512_downloaded = True\n", " model_512_downloaded = True\n",
" elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n", " # Download the secondary diffusion model v2\n",
" print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n", " if use_secondary_model == True:\n",
" else: \n", " if os.path.exists(model_secondary_path) and check_model_SHA:\n",
" wget(model_512_link, model_path)\n", " print('Checking Secondary Diffusion File')\n",
" model_512_downloaded = True\n", " with open(model_secondary_path,\"rb\") as f:\n",
"\n", " bytes = f.read() \n",
"\n", " hash = hashlib.sha256(bytes).hexdigest();\n",
"# Download the secondary diffusion model v2\n", " if hash == model_secondary_SHA:\n",
"if use_secondary_model == True:\n", " print('Secondary Model SHA matches')\n",
" if os.path.exists(model_secondary_path) and check_model_SHA:\n", " model_secondary_downloaded = True\n",
" print('Checking Secondary Diffusion File')\n", " else: \n",
" with open(model_secondary_path,\"rb\") as f:\n", " print(\"Secondary Model SHA doesn't match, redownloading...\")\n",
" bytes = f.read() \n", " wget(model_secondary_link, model_path)\n",
" hash = hashlib.sha256(bytes).hexdigest();\n", " if os.path.exists(model_secondary_path):\n",
" if hash == model_secondary_SHA:\n", " model_secondary_downloaded = True\n",
" print('Secondary Model SHA matches')\n", " else:\n",
" model_secondary_downloaded = True\n", " print('First URL Failed using FallBack')\n",
" download_models(diffusion_model,use_secondary_model,True)\n",
" elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n",
" print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n",
" else: \n", " else: \n",
" print(\"Secondary Model SHA doesn't match, redownloading...\")\n",
" wget(model_secondary_link, model_path)\n", " wget(model_secondary_link, model_path)\n",
" model_secondary_downloaded = True\n", " if os.path.exists(model_secondary_path):\n",
" elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n", " model_secondary_downloaded = True\n",
" print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n", " else:\n",
" else: \n", " print('First URL Failed using FallBack')\n",
" wget(model_secondary_link, model_path)\n", " download_models(diffusion_model,use_secondary_model,True)\n",
" model_secondary_downloaded = True\n", "\n",
"download_models(diffusion_model,use_secondary_model)\n",
"\n", "\n",
"model_config = model_and_diffusion_defaults()\n", "model_config = model_and_diffusion_defaults()\n",
"if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", "if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",

@ -451,10 +451,6 @@ else:
root_path = os.getcwd() root_path = os.getcwd()
model_path = f'{root_path}/models' model_path = f'{root_path}/models'
model_256_downloaded = False
model_512_downloaded = False
model_secondary_downloaded = False
multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8') multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')
print(multipip_res) print(multipip_res)
@ -1697,76 +1693,113 @@ RN50x64 = False #@param{type:"boolean"}
#@markdown If you're having issues with model downloads, check this to compare SHA's: #@markdown If you're having issues with model downloads, check this to compare SHA's:
check_model_SHA = False #@param{type:"boolean"} check_model_SHA = False #@param{type:"boolean"}
model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' def download_models(diffusion_model,use_secondary_model,fallback=False):
model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648' model_256_downloaded = False
model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' model_512_downloaded = False
model_secondary_downloaded = False
model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'
model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt' model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth' model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'
model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'
model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt' model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'
model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth' model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'
model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'
# Download the diffusion model
if diffusion_model == '256x256_diffusion_uncond': model_256_link_fb = 'https://www.dropbox.com/s/9tqnqo930mpnpcn/256x256_diffusion_uncond.pt'
if os.path.exists(model_256_path) and check_model_SHA: model_512_link_fb = 'https://www.dropbox.com/s/yjqvhu6l6l0r2eh/512x512_diffusion_uncond_finetune_008100.pt'
print('Checking 256 Diffusion File') model_secondary_link_fb = 'https://www.dropbox.com/s/luv4fezod3r8d2n/secondary_model_imagenet_2.pth'
with open(model_256_path,"rb") as f:
bytes = f.read() model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'
hash = hashlib.sha256(bytes).hexdigest(); model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'
if hash == model_256_SHA: model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'
print('256 Model SHA matches')
model_256_downloaded = True if fallback:
else: model_256_link = model_256_link_fb
print("256 Model SHA doesn't match, redownloading...") model_512_link = model_512_link_fb
model_secondary_link = model_secondary_link_fb
# Download the diffusion model
if diffusion_model == '256x256_diffusion_uncond':
if os.path.exists(model_256_path) and check_model_SHA:
print('Checking 256 Diffusion File')
with open(model_256_path,"rb") as f:
bytes = f.read()
hash = hashlib.sha256(bytes).hexdigest();
if hash == model_256_SHA:
print('256 Model SHA matches')
model_256_downloaded = True
else:
print("256 Model SHA doesn't match, redownloading...")
wget(model_256_link, model_path)
if os.path.exists(model_256_path):
model_256_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:
print('256 Model already downloaded, check check_model_SHA if the file is corrupt')
else:
wget(model_256_link, model_path) wget(model_256_link, model_path)
model_256_downloaded = True if os.path.exists(model_256_path):
elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True: model_256_downloaded = True
print('256 Model already downloaded, check check_model_SHA if the file is corrupt') else:
else: print('First URL Failed using FallBack')
wget(model_256_link, model_path) download_models(diffusion_model,True)
model_256_downloaded = True elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':
elif diffusion_model == '512x512_diffusion_uncond_finetune_008100': if os.path.exists(model_512_path) and check_model_SHA:
if os.path.exists(model_512_path) and check_model_SHA: print('Checking 512 Diffusion File')
print('Checking 512 Diffusion File') with open(model_512_path,"rb") as f:
with open(model_512_path,"rb") as f: bytes = f.read()
bytes = f.read() hash = hashlib.sha256(bytes).hexdigest();
hash = hashlib.sha256(bytes).hexdigest(); if hash == model_512_SHA:
if hash == model_512_SHA: print('512 Model SHA matches')
print('512 Model SHA matches') if os.path.exists(model_512_path):
model_512_downloaded = True model_512_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
else:
print("512 Model SHA doesn't match, redownloading...")
wget(model_512_link, model_path)
if os.path.exists(model_512_path):
model_512_downloaded = True
else:
print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:
print('512 Model already downloaded, check check_model_SHA if the file is corrupt')
else: else:
print("512 Model SHA doesn't match, redownloading...")
wget(model_512_link, model_path) wget(model_512_link, model_path)
model_512_downloaded = True model_512_downloaded = True
elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True: # Download the secondary diffusion model v2
print('512 Model already downloaded, check check_model_SHA if the file is corrupt') if use_secondary_model == True:
else: if os.path.exists(model_secondary_path) and check_model_SHA:
wget(model_512_link, model_path) print('Checking Secondary Diffusion File')
model_512_downloaded = True with open(model_secondary_path,"rb") as f:
bytes = f.read()
hash = hashlib.sha256(bytes).hexdigest();
# Download the secondary diffusion model v2 if hash == model_secondary_SHA:
if use_secondary_model == True: print('Secondary Model SHA matches')
if os.path.exists(model_secondary_path) and check_model_SHA: model_secondary_downloaded = True
print('Checking Secondary Diffusion File') else:
with open(model_secondary_path,"rb") as f: print("Secondary Model SHA doesn't match, redownloading...")
bytes = f.read() wget(model_secondary_link, model_path)
hash = hashlib.sha256(bytes).hexdigest(); if os.path.exists(model_secondary_path):
if hash == model_secondary_SHA: model_secondary_downloaded = True
print('Secondary Model SHA matches') else:
model_secondary_downloaded = True print('First URL Failed using FallBack')
download_models(diffusion_model,use_secondary_model,True)
elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:
print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')
else: else:
print("Secondary Model SHA doesn't match, redownloading...")
wget(model_secondary_link, model_path) wget(model_secondary_link, model_path)
model_secondary_downloaded = True if os.path.exists(model_secondary_path):
elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True: model_secondary_downloaded = True
print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt') else:
else: print('First URL Failed using FallBack')
wget(model_secondary_link, model_path) download_models(diffusion_model,use_secondary_model,True)
model_secondary_downloaded = True
download_models(diffusion_model,use_secondary_model)
model_config = model_and_diffusion_defaults() model_config = model_and_diffusion_defaults()
if diffusion_model == '512x512_diffusion_uncond_finetune_008100': if diffusion_model == '512x512_diffusion_uncond_finetune_008100':

Loading…
Cancel
Save