From 1fceed19b7d51862ce054946f13561fa17318755 Mon Sep 17 00:00:00 2001 From: MSFTserver Date: Mon, 2 May 2022 16:53:48 -0700 Subject: [PATCH] add fallback model URLS --- Disco_Diffusion.ipynb | 169 +++++++++++++++++++++++++----------------- disco.py | 169 +++++++++++++++++++++++++----------------- 2 files changed, 202 insertions(+), 136 deletions(-) diff --git a/Disco_Diffusion.ipynb b/Disco_Diffusion.ipynb index 5cfa846..6963f6c 100644 --- a/Disco_Diffusion.ipynb +++ b/Disco_Diffusion.ipynb @@ -478,10 +478,6 @@ " root_path = os.getcwd()\n", " model_path = f'{root_path}/models'\n", "\n", - "model_256_downloaded = False\n", - "model_512_downloaded = False\n", - "model_secondary_downloaded = False\n", - "\n", "multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", "print(multipip_res)\n", "\n", @@ -1744,76 +1740,113 @@ "#@markdown If you're having issues with model downloads, check this to compare SHA's:\n", "check_model_SHA = False #@param{type:\"boolean\"}\n", "\n", - "model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", - "model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n", - "model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", - "\n", - "model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n", - "model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n", - "model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n", - "\n", - "model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n", - "model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n", - "model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n", - "\n", - "# Download the diffusion model\n", - "if diffusion_model == '256x256_diffusion_uncond':\n", - " if os.path.exists(model_256_path) and check_model_SHA:\n", - " print('Checking 256 Diffusion File')\n", - " with open(model_256_path,\"rb\") as f:\n", - " bytes = f.read() \n", - " hash = hashlib.sha256(bytes).hexdigest();\n", - " if hash == model_256_SHA:\n", - " print('256 Model SHA matches')\n", - " model_256_downloaded = True\n", - " else: \n", - " print(\"256 Model SHA doesn't match, redownloading...\")\n", + "def download_models(diffusion_model,use_secondary_model,fallback=False):\n", + " model_256_downloaded = False\n", + " model_512_downloaded = False\n", + " model_secondary_downloaded = False\n", + "\n", + " model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", + " model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n", + " model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", + "\n", + " model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n", + " model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n", + " model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n", + "\n", + " model_256_link_fb = 'https://www.dropbox.com/s/9tqnqo930mpnpcn/256x256_diffusion_uncond.pt'\n", + " model_512_link_fb = 'https://www.dropbox.com/s/yjqvhu6l6l0r2eh/512x512_diffusion_uncond_finetune_008100.pt'\n", + " model_secondary_link_fb = 'https://www.dropbox.com/s/luv4fezod3r8d2n/secondary_model_imagenet_2.pth'\n", + "\n", + " model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n", + " model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n", + " model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n", + "\n", + " if fallback:\n", + " model_256_link = model_256_link_fb\n", + " model_512_link = model_512_link_fb\n", + " model_secondary_link = model_secondary_link_fb\n", + " # Download the diffusion model\n", + " if diffusion_model == '256x256_diffusion_uncond':\n", + " if os.path.exists(model_256_path) and check_model_SHA:\n", + " print('Checking 256 Diffusion File')\n", + " with open(model_256_path,\"rb\") as f:\n", + " bytes = f.read() \n", + " hash = hashlib.sha256(bytes).hexdigest();\n", + " if hash == model_256_SHA:\n", + " print('256 Model SHA matches')\n", + " model_256_downloaded = True\n", + " else: \n", + " print(\"256 Model SHA doesn't match, redownloading...\")\n", + " wget(model_256_link, model_path)\n", + " if os.path.exists(model_256_path):\n", + " model_256_downloaded = True\n", + " else:\n", + " print('First URL Failed using FallBack')\n", + " download_models(diffusion_model,use_secondary_model,True)\n", + " elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n", + " print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n", + " else: \n", " wget(model_256_link, model_path)\n", - " model_256_downloaded = True\n", - " elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n", - " print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n", - " else: \n", - " wget(model_256_link, model_path)\n", - " model_256_downloaded = True\n", - "elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", - " if os.path.exists(model_512_path) and check_model_SHA:\n", - " print('Checking 512 Diffusion File')\n", - " with open(model_512_path,\"rb\") as f:\n", - " bytes = f.read() \n", - " hash = hashlib.sha256(bytes).hexdigest();\n", - " if hash == model_512_SHA:\n", - " print('512 Model SHA matches')\n", - " model_512_downloaded = True\n", + " if os.path.exists(model_256_path):\n", + " model_256_downloaded = True\n", + " else:\n", + " print('First URL Failed using FallBack')\n", + " download_models(diffusion_model,True)\n", + " elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", + " if os.path.exists(model_512_path) and check_model_SHA:\n", + " print('Checking 512 Diffusion File')\n", + " with open(model_512_path,\"rb\") as f:\n", + " bytes = f.read() \n", + " hash = hashlib.sha256(bytes).hexdigest();\n", + " if hash == model_512_SHA:\n", + " print('512 Model SHA matches')\n", + " if os.path.exists(model_512_path):\n", + " model_512_downloaded = True\n", + " else:\n", + " print('First URL Failed using FallBack')\n", + " download_models(diffusion_model,use_secondary_model,True)\n", + " else: \n", + " print(\"512 Model SHA doesn't match, redownloading...\")\n", + " wget(model_512_link, model_path)\n", + " if os.path.exists(model_512_path):\n", + " model_512_downloaded = True\n", + " else:\n", + " print('First URL Failed using FallBack')\n", + " download_models(diffusion_model,use_secondary_model,True)\n", + " elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n", + " print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n", " else: \n", - " print(\"512 Model SHA doesn't match, redownloading...\")\n", " wget(model_512_link, model_path)\n", " model_512_downloaded = True\n", - " elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n", - " print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n", - " else: \n", - " wget(model_512_link, model_path)\n", - " model_512_downloaded = True\n", - "\n", - "\n", - "# Download the secondary diffusion model v2\n", - "if use_secondary_model == True:\n", - " if os.path.exists(model_secondary_path) and check_model_SHA:\n", - " print('Checking Secondary Diffusion File')\n", - " with open(model_secondary_path,\"rb\") as f:\n", - " bytes = f.read() \n", - " hash = hashlib.sha256(bytes).hexdigest();\n", - " if hash == model_secondary_SHA:\n", - " print('Secondary Model SHA matches')\n", - " model_secondary_downloaded = True\n", + " # Download the secondary diffusion model v2\n", + " if use_secondary_model == True:\n", + " if os.path.exists(model_secondary_path) and check_model_SHA:\n", + " print('Checking Secondary Diffusion File')\n", + " with open(model_secondary_path,\"rb\") as f:\n", + " bytes = f.read() \n", + " hash = hashlib.sha256(bytes).hexdigest();\n", + " if hash == model_secondary_SHA:\n", + " print('Secondary Model SHA matches')\n", + " model_secondary_downloaded = True\n", + " else: \n", + " print(\"Secondary Model SHA doesn't match, redownloading...\")\n", + " wget(model_secondary_link, model_path)\n", + " if os.path.exists(model_secondary_path):\n", + " model_secondary_downloaded = True\n", + " else:\n", + " print('First URL Failed using FallBack')\n", + " download_models(diffusion_model,use_secondary_model,True)\n", + " elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n", + " print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n", " else: \n", - " print(\"Secondary Model SHA doesn't match, redownloading...\")\n", " wget(model_secondary_link, model_path)\n", - " model_secondary_downloaded = True\n", - " elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n", - " print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n", - " else: \n", - " wget(model_secondary_link, model_path)\n", - " model_secondary_downloaded = True\n", + " if os.path.exists(model_secondary_path):\n", + " model_secondary_downloaded = True\n", + " else:\n", + " print('First URL Failed using FallBack')\n", + " download_models(diffusion_model,use_secondary_model,True)\n", + "\n", + "download_models(diffusion_model,use_secondary_model)\n", "\n", "model_config = model_and_diffusion_defaults()\n", "if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", diff --git a/disco.py b/disco.py index 1bf172e..0379e2e 100644 --- a/disco.py +++ b/disco.py @@ -451,10 +451,6 @@ else: root_path = os.getcwd() model_path = f'{root_path}/models' -model_256_downloaded = False -model_512_downloaded = False -model_secondary_downloaded = False - multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8') print(multipip_res) @@ -1697,76 +1693,113 @@ RN50x64 = False #@param{type:"boolean"} #@markdown If you're having issues with model downloads, check this to compare SHA's: check_model_SHA = False #@param{type:"boolean"} -model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' -model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648' -model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' - -model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' -model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt' -model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth' - -model_256_path = f'{model_path}/256x256_diffusion_uncond.pt' -model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt' -model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth' - -# Download the diffusion model -if diffusion_model == '256x256_diffusion_uncond': - if os.path.exists(model_256_path) and check_model_SHA: - print('Checking 256 Diffusion File') - with open(model_256_path,"rb") as f: - bytes = f.read() - hash = hashlib.sha256(bytes).hexdigest(); - if hash == model_256_SHA: - print('256 Model SHA matches') - model_256_downloaded = True - else: - print("256 Model SHA doesn't match, redownloading...") +def download_models(diffusion_model,use_secondary_model,fallback=False): + model_256_downloaded = False + model_512_downloaded = False + model_secondary_downloaded = False + + model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' + model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648' + model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' + + model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' + model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt' + model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth' + + model_256_link_fb = 'https://www.dropbox.com/s/9tqnqo930mpnpcn/256x256_diffusion_uncond.pt' + model_512_link_fb = 'https://www.dropbox.com/s/yjqvhu6l6l0r2eh/512x512_diffusion_uncond_finetune_008100.pt' + model_secondary_link_fb = 'https://www.dropbox.com/s/luv4fezod3r8d2n/secondary_model_imagenet_2.pth' + + model_256_path = f'{model_path}/256x256_diffusion_uncond.pt' + model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt' + model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth' + + if fallback: + model_256_link = model_256_link_fb + model_512_link = model_512_link_fb + model_secondary_link = model_secondary_link_fb + # Download the diffusion model + if diffusion_model == '256x256_diffusion_uncond': + if os.path.exists(model_256_path) and check_model_SHA: + print('Checking 256 Diffusion File') + with open(model_256_path,"rb") as f: + bytes = f.read() + hash = hashlib.sha256(bytes).hexdigest(); + if hash == model_256_SHA: + print('256 Model SHA matches') + model_256_downloaded = True + else: + print("256 Model SHA doesn't match, redownloading...") + wget(model_256_link, model_path) + if os.path.exists(model_256_path): + model_256_downloaded = True + else: + print('First URL Failed using FallBack') + download_models(diffusion_model,use_secondary_model,True) + elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True: + print('256 Model already downloaded, check check_model_SHA if the file is corrupt') + else: wget(model_256_link, model_path) - model_256_downloaded = True - elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True: - print('256 Model already downloaded, check check_model_SHA if the file is corrupt') - else: - wget(model_256_link, model_path) - model_256_downloaded = True -elif diffusion_model == '512x512_diffusion_uncond_finetune_008100': - if os.path.exists(model_512_path) and check_model_SHA: - print('Checking 512 Diffusion File') - with open(model_512_path,"rb") as f: - bytes = f.read() - hash = hashlib.sha256(bytes).hexdigest(); - if hash == model_512_SHA: - print('512 Model SHA matches') - model_512_downloaded = True + if os.path.exists(model_256_path): + model_256_downloaded = True + else: + print('First URL Failed using FallBack') + download_models(diffusion_model,True) + elif diffusion_model == '512x512_diffusion_uncond_finetune_008100': + if os.path.exists(model_512_path) and check_model_SHA: + print('Checking 512 Diffusion File') + with open(model_512_path,"rb") as f: + bytes = f.read() + hash = hashlib.sha256(bytes).hexdigest(); + if hash == model_512_SHA: + print('512 Model SHA matches') + if os.path.exists(model_512_path): + model_512_downloaded = True + else: + print('First URL Failed using FallBack') + download_models(diffusion_model,use_secondary_model,True) + else: + print("512 Model SHA doesn't match, redownloading...") + wget(model_512_link, model_path) + if os.path.exists(model_512_path): + model_512_downloaded = True + else: + print('First URL Failed using FallBack') + download_models(diffusion_model,use_secondary_model,True) + elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True: + print('512 Model already downloaded, check check_model_SHA if the file is corrupt') else: - print("512 Model SHA doesn't match, redownloading...") wget(model_512_link, model_path) model_512_downloaded = True - elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True: - print('512 Model already downloaded, check check_model_SHA if the file is corrupt') - else: - wget(model_512_link, model_path) - model_512_downloaded = True - - -# Download the secondary diffusion model v2 -if use_secondary_model == True: - if os.path.exists(model_secondary_path) and check_model_SHA: - print('Checking Secondary Diffusion File') - with open(model_secondary_path,"rb") as f: - bytes = f.read() - hash = hashlib.sha256(bytes).hexdigest(); - if hash == model_secondary_SHA: - print('Secondary Model SHA matches') - model_secondary_downloaded = True + # Download the secondary diffusion model v2 + if use_secondary_model == True: + if os.path.exists(model_secondary_path) and check_model_SHA: + print('Checking Secondary Diffusion File') + with open(model_secondary_path,"rb") as f: + bytes = f.read() + hash = hashlib.sha256(bytes).hexdigest(); + if hash == model_secondary_SHA: + print('Secondary Model SHA matches') + model_secondary_downloaded = True + else: + print("Secondary Model SHA doesn't match, redownloading...") + wget(model_secondary_link, model_path) + if os.path.exists(model_secondary_path): + model_secondary_downloaded = True + else: + print('First URL Failed using FallBack') + download_models(diffusion_model,use_secondary_model,True) + elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True: + print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt') else: - print("Secondary Model SHA doesn't match, redownloading...") wget(model_secondary_link, model_path) - model_secondary_downloaded = True - elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True: - print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt') - else: - wget(model_secondary_link, model_path) - model_secondary_downloaded = True + if os.path.exists(model_secondary_path): + model_secondary_downloaded = True + else: + print('First URL Failed using FallBack') + download_models(diffusion_model,use_secondary_model,True) + +download_models(diffusion_model,use_secondary_model) model_config = model_and_diffusion_defaults() if diffusion_model == '512x512_diffusion_uncond_finetune_008100':