|
|
|
@ -450,7 +450,7 @@ |
|
|
|
|
"if not path_exists(f'{model_path}'):\n", |
|
|
|
|
" pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)\n", |
|
|
|
|
"if not path_exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n", |
|
|
|
|
" wget(\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\", out=model_path)\n", |
|
|
|
|
" wget(\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\", model_path)\n", |
|
|
|
|
"\n", |
|
|
|
|
"import sys\n", |
|
|
|
|
"import torch\n", |
|
|
|
@ -547,7 +547,7 @@ |
|
|
|
|
" if is_colab:\n", |
|
|
|
|
" gitclone(\"https://github.com/shariqfarooq123/AdaBins.git\")\n", |
|
|
|
|
" if not path_exists(f'{model_path}/AdaBins_nyu.pt'):\n", |
|
|
|
|
" wget(\"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt\", out=model_path)\n", |
|
|
|
|
" wget(\"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt\", model_path)\n", |
|
|
|
|
" pathlib.Path(\"pretrained\").mkdir(parents=True, exist_ok=True)\n", |
|
|
|
|
" shutil.copyfile(f\"{model_path}/AdaBins_nyu.pt\", \"pretrained/AdaBins_nyu.pt\")\n", |
|
|
|
|
" sys.path.append('./AdaBins')\n", |
|
|
|
@ -2191,12 +2191,12 @@ |
|
|
|
|
" model_256_downloaded = True\n", |
|
|
|
|
" else: \n", |
|
|
|
|
" print(\"256 Model SHA doesn't match, redownloading...\")\n", |
|
|
|
|
" wget(model_256_link, out=model_path)\n", |
|
|
|
|
" wget(model_256_link, model_path)\n", |
|
|
|
|
" model_256_downloaded = True\n", |
|
|
|
|
" elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n", |
|
|
|
|
" print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n", |
|
|
|
|
" else: \n", |
|
|
|
|
" wget(model_256_link, out=model_path)\n", |
|
|
|
|
" wget(model_256_link, model_path)\n", |
|
|
|
|
" model_256_downloaded = True\n", |
|
|
|
|
"elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", |
|
|
|
|
" if os.path.exists(model_512_path) and check_model_SHA:\n", |
|
|
|
@ -2209,12 +2209,12 @@ |
|
|
|
|
" model_512_downloaded = True\n", |
|
|
|
|
" else: \n", |
|
|
|
|
" print(\"512 Model SHA doesn't match, redownloading...\")\n", |
|
|
|
|
" wget(model_512_link, out=model_path)\n", |
|
|
|
|
" wget(model_512_link, model_path)\n", |
|
|
|
|
" model_512_downloaded = True\n", |
|
|
|
|
" elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n", |
|
|
|
|
" print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n", |
|
|
|
|
" else: \n", |
|
|
|
|
" wget(model_512_link, out=model_path)\n", |
|
|
|
|
" wget(model_512_link, model_path)\n", |
|
|
|
|
" model_512_downloaded = True\n", |
|
|
|
|
"\n", |
|
|
|
|
"\n", |
|
|
|
@ -2230,12 +2230,12 @@ |
|
|
|
|
" model_secondary_downloaded = True\n", |
|
|
|
|
" else: \n", |
|
|
|
|
" print(\"Secondary Model SHA doesn't match, redownloading...\")\n", |
|
|
|
|
" wget(model_secondary_link, out=model_path)\n", |
|
|
|
|
" wget(model_secondary_link, model_path)\n", |
|
|
|
|
" model_secondary_downloaded = True\n", |
|
|
|
|
" elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n", |
|
|
|
|
" print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n", |
|
|
|
|
" else: \n", |
|
|
|
|
" wget(model_secondary_link, out=model_path)\n", |
|
|
|
|
" wget(model_secondary_link, model_path)\n", |
|
|
|
|
" model_secondary_downloaded = True\n", |
|
|
|
|
"\n", |
|
|
|
|
"model_config = model_and_diffusion_defaults()\n", |
|
|
|
@ -2299,7 +2299,7 @@ |
|
|
|
|
"if SLIPB16:\n", |
|
|
|
|
" SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n", |
|
|
|
|
" if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):\n", |
|
|
|
|
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt\", out=model_path)\n", |
|
|
|
|
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt\", model_path)\n", |
|
|
|
|
" sd = torch.load(f'{model_path}/slip_base_100ep.pt')\n", |
|
|
|
|
" real_sd = {}\n", |
|
|
|
|
" for k, v in sd['state_dict'].items():\n", |
|
|
|
@ -2313,7 +2313,7 @@ |
|
|
|
|
"if SLIPL16:\n", |
|
|
|
|
" SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n", |
|
|
|
|
" if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n", |
|
|
|
|
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", out=model_path)\n", |
|
|
|
|
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", model_path)\n", |
|
|
|
|
" sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n", |
|
|
|
|
" real_sd = {}\n", |
|
|
|
|
" for k, v in sd['state_dict'].items():\n", |
|
|
|
|