Merge pull request #27 from MSFTserver/cross-platofrm

Cross Platform update
Adam Letts authored 3 years ago · committed by GitHub
commit 4acb77c2e6
  1. Disco_Diffusion.ipynb (785 lines changed)
  2. README.md (18 lines changed)
  3. disco.py (863 lines changed)

@@ -241,7 +241,7 @@
"\n",
" IPython magic commands replaced by Python code\n",
"\n",
" v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts\n",
" v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts / MSFTserver aka HostsServer\n",
"\n",
" Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.\n",
"\n",
@@ -253,8 +253,16 @@
"\n",
" Added video_init_seed_continuity option to make init video animations more continuous\n",
"\n",
" Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion\n",
"\n",
" Remove Super Resolution\n",
"\n",
" Remove SLIP Models\n",
"\n",
" Update for crossplatform support\n",
"\n",
" '''\n",
" )\n"
" )"
],
"outputs": [],
"execution_count": null
@@ -362,9 +370,7 @@
},
"source": [
"#@title 1.2 Prepare Folders\n",
"import subprocess\n",
"import sys\n",
"import ipykernel\n",
"import subprocess, os, sys, ipykernel\n",
"\n",
"def gitclone(url):\n",
" res = subprocess.run(['git', 'clone', url], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
@@ -403,7 +409,7 @@
" else:\n",
" root_path = '/content'\n",
"else:\n",
" root_path = '.'\n",
" root_path = os.getcwd()\n",
"\n",
"import os\n",
"def createPath(filepath):\n",
@@ -416,13 +422,13 @@
"\n",
"if is_colab:\n",
" if google_drive and not save_models_to_google_drive or not google_drive:\n",
" model_path = '/content/model'\n",
" model_path = '/content/models'\n",
" createPath(model_path)\n",
" if google_drive and save_models_to_google_drive:\n",
" model_path = f'{root_path}/model'\n",
" model_path = f'{root_path}/models'\n",
" createPath(model_path)\n",
"else:\n",
" model_path = f'{root_path}/model'\n",
" model_path = f'{root_path}/models'\n",
" createPath(model_path)\n",
"\n",
"# libraries = f'{root_path}/libraries'\n",
@@ -440,7 +446,7 @@
"source": [
"#@title ### 1.3 Install and import dependencies\n",
"\n",
"import pathlib, shutil\n",
"import pathlib, shutil, os, sys\n",
"\n",
"if not is_colab:\n",
" # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.\n",
@@ -454,48 +460,70 @@
" root_path = f'/content'\n",
" model_path = '/content/models' \n",
"else:\n",
" root_path = f'.'\n",
" model_path = f'{root_path}/model'\n",
" root_path = os.getcwd()\n",
" model_path = f'{root_path}/models'\n",
"\n",
"model_256_downloaded = False\n",
"model_512_downloaded = False\n",
"model_secondary_downloaded = False\n",
"\n",
"multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"print(multipip_res)\n",
"\n",
"if is_colab:\n",
" gitclone(\"https://github.com/openai/CLIP\")\n",
" #gitclone(\"https://github.com/facebookresearch/SLIP.git\")\n",
" gitclone(\"https://github.com/crowsonkb/guided-diffusion\")\n",
" gitclone(\"https://github.com/assafshocher/ResizeRight.git\")\n",
" gitclone(\"https://github.com/MSFTserver/pytorch3d-lite.git\")\n",
" pipie(\"./CLIP\")\n",
" pipie(\"./guided-diffusion\")\n",
" multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
" print(multipip_res)\n",
" subprocess.run(['apt', 'install', 'imagemagick'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
" gitclone(\"https://github.com/isl-org/MiDaS.git\")\n",
" gitclone(\"https://github.com/alembics/disco-diffusion.git\")\n",
" pipi(\"pytorch-lightning\")\n",
" pipi(\"omegaconf\")\n",
" pipi(\"einops\")\n",
" # Rename a file to avoid a name conflict..\n",
" try:\n",
" os.rename(\"MiDaS/utils.py\", \"MiDaS/midas_utils.py\")\n",
" shutil.copyfile(\"disco-diffusion/disco_xform_utils.py\", \"disco_xform_utils.py\")\n",
" except:\n",
" pass\n",
"\n",
"if not os.path.exists(f'{model_path}'):\n",
" pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)\n",
"if not os.path.exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n",
" wget(\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\", model_path)\n",
"try:\n",
" import clip\n",
"except:\n",
" if os.path.exists(\"CLIP\") is not True:\n",
" gitclone(\"https://github.com/openai/CLIP\")\n",
" sys.path.append(f'{root_path}/CLIP')\n",
"\n",
"import sys\n",
"import torch\n",
"try:\n",
" from guided_diffusion.script_util import create_model_and_diffusion\n",
"except:\n",
" if os.path.exists(\"guided-diffusion\") is not True:\n",
" gitclone(\"https://github.com/crowsonkb/guided-diffusion\")\n",
" sys.path.append(f'{PROJECT_DIR}/guided-diffusion')\n",
"\n",
"try:\n",
" from resize_right import resize\n",
"except:\n",
" if os.path.exists(\"resize_right\") is not True:\n",
" gitclone(\"https://github.com/assafshocher/ResizeRight.git\")\n",
" sys.path.append(f'{PROJECT_DIR}/ResizeRight')\n",
"\n",
"try:\n",
" import py3d_tools\n",
"except:\n",
" if os.path.exists('pytorch3d-lite') is not True:\n",
" gitclone(\"https://github.com/MSFTserver/pytorch3d-lite.git\")\n",
" sys.path.append(f'{PROJECT_DIR}/pytorch3d-lite')\n",
"\n",
"try:\n",
" from midas.dpt_depth import DPTDepthModel\n",
"except:\n",
" if os.path.exists('MiDaS') is not True:\n",
" gitclone(\"https://github.com/isl-org/MiDaS.git\")\n",
" if os.path.exists('MiDaS/midas_utils.py') is not True:\n",
" shutil.move('MiDaS/utils.py', 'MiDaS/midas_utils.py')\n",
" if not os.path.exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n",
" wget(\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\", model_path)\n",
" sys.path.append(f'{PROJECT_DIR}/MiDaS')\n",
"\n",
"try:\n",
" sys.path.append(PROJECT_DIR)\n",
" import disco_xform_utils as dxf\n",
"except:\n",
" if os.path.exists(\"disco-diffusion\") is not True:\n",
" gitclone(\"https://github.com/alembics/disco-diffusion.git\")\n",
" # Rename a file to avoid a name conflict..\n",
" if os.path.exists('disco_xform_utils.py') is not True:\n",
" shutil.move('disco-diffusion/disco_xform_utils.py', 'disco_xform_utils.py')\n",
" sys.path.append(PROJECT_DIR)\n",
"\n",
"# sys.path.append('./SLIP')\n",
"sys.path.append('./pytorch3d-lite')\n",
"sys.path.append('./ResizeRight')\n",
"sys.path.append('./MiDaS')\n",
"import torch\n",
"from dataclasses import dataclass\n",
"from functools import partial\n",
"import cv2\n",
@@ -516,11 +544,8 @@
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"from tqdm.notebook import tqdm\n",
"sys.path.append('./CLIP')\n",
"sys.path.append('./guided-diffusion')\n",
"import clip\n",
"from resize_right import resize\n",
"# from models import SLIP_VITB16, SLIP, SLIP_VITL16\n",
"from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n",
"from datetime import datetime\n",
"import numpy as np\n",
@@ -528,32 +553,7 @@
"import random\n",
"from ipywidgets import Output\n",
"import hashlib\n",
"\n",
"#SuperRes\n",
"if is_colab:\n",
" gitclone(\"https://github.com/CompVis/latent-diffusion.git\")\n",
" gitclone(\"https://github.com/CompVis/taming-transformers\")\n",
" pipie(\"./taming-transformers\")\n",
" pipi(\"ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\")\n",
"\n",
"#SuperRes\n",
"import ipywidgets as widgets\n",
"import os\n",
"sys.path.append(\".\")\n",
"sys.path.append('./taming-transformers')\n",
"from taming.models import vqgan # checking correct import from taming\n",
"from torchvision.datasets.utils import download_url\n",
"\n",
"if is_colab:\n",
" os.chdir('/content/latent-diffusion')\n",
"else:\n",
" #os.chdir('latent-diffusion')\n",
" sys.path.append('latent-diffusion')\n",
"from functools import partial\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\n",
"# from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.util import ismap\n",
"if is_colab:\n",
" os.chdir('/content')\n",
" from google.colab import files\n",
@@ -570,13 +570,15 @@
"\n",
"# AdaBins stuff\n",
"if USE_ADABINS:\n",
" if is_colab:\n",
" gitclone(\"https://github.com/shariqfarooq123/AdaBins.git\")\n",
" if not os.path.exists(f'{model_path}/AdaBins_nyu.pt'):\n",
" wget(\"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt\", model_path)\n",
" pathlib.Path(\"pretrained\").mkdir(parents=True, exist_ok=True)\n",
" shutil.copyfile(f\"{model_path}/AdaBins_nyu.pt\", \"pretrained/AdaBins_nyu.pt\")\n",
" sys.path.append('./AdaBins')\n",
" try:\n",
" from infer import InferenceHelper\n",
" except:\n",
" if os.path.exists(\"AdaBins\") is not True:\n",
" gitclone(\"https://github.com/shariqfarooq123/AdaBins.git\")\n",
" if not path_exists(f'{model_path}/pretrained/AdaBins_nyu.pt'):\n",
" os.makedirs(f'{model_path}/pretrained')\n",
" wget(\"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt\", f'{model_path}/pretrained')\n",
" sys.path.append(f'{os.getcwd()}/AdaBins')\n",
" from infer import InferenceHelper\n",
" MAX_ADABINS_AREA = 500000\n",
"\n",
@@ -1354,7 +1356,6 @@
" \n",
" # with run_display:\n",
" # display.clear_output(wait=True)\n",
" imgToSharpen = None\n",
" for j, sample in enumerate(samples): \n",
" cur_t -= 1\n",
" intermediateStep = False\n",
@@ -1403,31 +1404,20 @@
" save_settings()\n",
" if args.animation_mode != \"None\":\n",
" image.save('prevFrame.png')\n",
" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
" imgToSharpen = image\n",
" if args.keep_unsharp is True:\n",
" image.save(f'{unsharpenFolder}/{filename}')\n",
" else:\n",
" image.save(f'{batchFolder}/{filename}')\n",
" if args.animation_mode == \"3D\":\n",
" # If turbo, save a blended image\n",
" if turbo_mode and frame_num > 0:\n",
" # Mix new image with prevFrameScaled\n",
" blend_factor = (1)/int(turbo_steps)\n",
" newFrame = cv2.imread('prevFrame.png') # This is already updated..\n",
" prev_frame_warped = cv2.imread('prevFrameScaled.png')\n",
" blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n",
" cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
" else:\n",
" image.save(f'{batchFolder}/{filename}')\n",
" image.save(f'{batchFolder}/{filename}')\n",
" if args.animation_mode == \"3D\":\n",
" # If turbo, save a blended image\n",
" if turbo_mode and frame_num > 0:\n",
" # Mix new image with prevFrameScaled\n",
" blend_factor = (1)/int(turbo_steps)\n",
" newFrame = cv2.imread('prevFrame.png') # This is already updated..\n",
" prev_frame_warped = cv2.imread('prevFrameScaled.png')\n",
" blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n",
" cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
" else:\n",
" image.save(f'{batchFolder}/{filename}')\n",
" # if frame_num != args.max_frames-1:\n",
" # display.clear_output()\n",
"\n",
" with image_display: \n",
" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
" print('Starting Diffusion Sharpening...')\n",
" do_superres(imgToSharpen, f'{batchFolder}/{filename}')\n",
" display.clear_output()\n",
" \n",
" plt.plot(np.array(loss_values), 'r')\n",
"\n",
@@ -1685,539 +1675,6 @@
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "DefSuperRes"
},
"source": [
"#@title 1.7 SuperRes Define\n",
"class DDIMSampler(object):\n",
" def __init__(self, model, schedule=\"linear\", **kwargs):\n",
" super().__init__()\n",
" self.model = model\n",
" self.ddpm_num_timesteps = model.num_timesteps\n",
" self.schedule = schedule\n",
"\n",
" def register_buffer(self, name, attr):\n",
" if type(attr) == torch.Tensor:\n",
" if attr.device != torch.device(\"cuda\"):\n",
" attr = attr.to(torch.device(\"cuda\"))\n",
" setattr(self, name, attr)\n",
"\n",
" def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n",
" self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n",
" num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n",
" alphas_cumprod = self.model.alphas_cumprod\n",
" assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n",
" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n",
"\n",
" self.register_buffer('betas', to_torch(self.model.betas))\n",
" self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n",
" self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n",
"\n",
" # calculations for diffusion q(x_t | x_{t-1}) and others\n",
" self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n",
" self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n",
" self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n",
" self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n",
" self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n",
"\n",
" # ddim sampling parameters\n",
" ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n",
" ddim_timesteps=self.ddim_timesteps,\n",
" eta=ddim_eta,verbose=verbose)\n",
" self.register_buffer('ddim_sigmas', ddim_sigmas)\n",
" self.register_buffer('ddim_alphas', ddim_alphas)\n",
" self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n",
" self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n",
" sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n",
" (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n",
" 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n",
" self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n",
"\n",
" @torch.no_grad()\n",
" def sample(self,\n",
" S,\n",
" batch_size,\n",
" shape,\n",
" conditioning=None,\n",
" callback=None,\n",
" normals_sequence=None,\n",
" img_callback=None,\n",
" quantize_x0=False,\n",
" eta=0.,\n",
" mask=None,\n",
" x0=None,\n",
" temperature=1.,\n",
" noise_dropout=0.,\n",
" score_corrector=None,\n",
" corrector_kwargs=None,\n",
" verbose=True,\n",
" x_T=None,\n",
" log_every_t=100,\n",
" **kwargs\n",
" ):\n",
" if conditioning is not None:\n",
" if isinstance(conditioning, dict):\n",
" cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
" if cbs != batch_size:\n",
" print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
" else:\n",
" if conditioning.shape[0] != batch_size:\n",
" print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n",
"\n",
" self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
" # sampling\n",
" C, H, W = shape\n",
" size = (batch_size, C, H, W)\n",
" # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
"\n",
" samples, intermediates = self.ddim_sampling(conditioning, size,\n",
" callback=callback,\n",
" img_callback=img_callback,\n",
" quantize_denoised=quantize_x0,\n",
" mask=mask, x0=x0,\n",
" ddim_use_original_steps=False,\n",
" noise_dropout=noise_dropout,\n",
" temperature=temperature,\n",
" score_corrector=score_corrector,\n",
" corrector_kwargs=corrector_kwargs,\n",
" x_T=x_T,\n",
" log_every_t=log_every_t\n",
" )\n",
" return samples, intermediates\n",
"\n",
" @torch.no_grad()\n",
" def ddim_sampling(self, cond, shape,\n",
" x_T=None, ddim_use_original_steps=False,\n",
" callback=None, timesteps=None, quantize_denoised=False,\n",
" mask=None, x0=None, img_callback=None, log_every_t=100,\n",
" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
" device = self.model.betas.device\n",
" b = shape[0]\n",
" if x_T is None:\n",
" img = torch.randn(shape, device=device)\n",
" else:\n",
" img = x_T\n",
"\n",
" if timesteps is None:\n",
" timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n",
" elif timesteps is not None and not ddim_use_original_steps:\n",
" subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n",
" timesteps = self.ddim_timesteps[:subset_end]\n",
"\n",
" intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
" time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n",
" total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
" print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n",
"\n",
" iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n",
"\n",
" for i, step in enumerate(iterator):\n",
" index = total_steps - i - 1\n",
" ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
"\n",
" if mask is not None:\n",
" assert x0 is not None\n",
" img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n",
" img = img_orig * mask + (1. - mask) * img\n",
"\n",
" outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n",
" quantize_denoised=quantize_denoised, temperature=temperature,\n",
" noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
" corrector_kwargs=corrector_kwargs)\n",
" img, pred_x0 = outs\n",
" if callback: callback(i)\n",
" if img_callback: img_callback(pred_x0, i)\n",
"\n",
" if index % log_every_t == 0 or index == total_steps - 1:\n",
" intermediates['x_inter'].append(img)\n",
" intermediates['pred_x0'].append(pred_x0)\n",
"\n",
" return img, intermediates\n",
"\n",
" @torch.no_grad()\n",
" def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n",
" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
" b, *_, device = *x.shape, x.device\n",
" e_t = self.model.apply_model(x, t, c)\n",
" if score_corrector is not None:\n",
" assert self.model.parameterization == \"eps\"\n",
" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n",
"\n",
" alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n",
" alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n",
" sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n",
" sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n",
" # select parameters corresponding to the currently considered timestep\n",
" a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n",
" a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n",
" sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n",
" sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n",
"\n",
" # current prediction for x_0\n",
" pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n",
" if quantize_denoised:\n",
" pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n",
" # direction pointing to x_t\n",
" dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n",
" noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n",
" if noise_dropout > 0.:\n",
" noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n",
" x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n",
" return x_prev, pred_x0\n",
"\n",
"\n",
"def download_models(mode):\n",
"\n",
" if mode == \"superresolution\":\n",
" # this is the small bsr light model\n",
" url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n",
" url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n",
"\n",
" path_conf = f'{model_path}/superres/project.yaml'\n",
" path_ckpt = f'{model_path}/superres/last.ckpt'\n",
"\n",
" download_url(url_conf, path_conf)\n",
" download_url(url_ckpt, path_ckpt)\n",
"\n",
" path_conf = path_conf + '/?dl=1' # fix it\n",
" path_ckpt = path_ckpt + '/?dl=1' # fix it\n",
" return path_conf, path_ckpt\n",
"\n",
" else:\n",
" raise NotImplementedError\n",
"\n",
"\n",
"def load_model_from_config(config, ckpt):\n",
" print(f\"Loading model from {ckpt}\")\n",
" pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
" global_step = pl_sd[\"global_step\"]\n",
" sd = pl_sd[\"state_dict\"]\n",
" model = instantiate_from_config(config.model)\n",
" m, u = model.load_state_dict(sd, strict=False)\n",
" model.cuda()\n",
" model.eval()\n",
" return {\"model\": model}, global_step\n",
"\n",
"\n",
"def get_model(mode):\n",
" path_conf, path_ckpt = download_models(mode)\n",
" config = OmegaConf.load(path_conf)\n",
" model, step = load_model_from_config(config, path_ckpt)\n",
" return model\n",
"\n",
"\n",
"def get_custom_cond(mode):\n",
" dest = \"data/example_conditioning\"\n",
"\n",
" if mode == \"superresolution\":\n",
" uploaded_img = files.upload()\n",
" filename = next(iter(uploaded_img))\n",
" name, filetype = filename.split(\".\") # todo assumes just one dot in name !\n",
" os.rename(f\"{filename}\", f\"{dest}/{mode}/custom_{name}.{filetype}\")\n",
"\n",
" elif mode == \"text_conditional\":\n",
" w = widgets.Text(value='A cake with cream!', disabled=True)\n",
" display.display(w)\n",
"\n",
" with open(f\"{dest}/{mode}/custom_{w.value[:20]}.txt\", 'w') as f:\n",
" f.write(w.value)\n",
"\n",
" elif mode == \"class_conditional\":\n",
" w = widgets.IntSlider(min=0, max=1000)\n",
" display.display(w)\n",
" with open(f\"{dest}/{mode}/custom.txt\", 'w') as f:\n",
" f.write(w.value)\n",
"\n",
" else:\n",
" raise NotImplementedError(f\"cond not implemented for mode{mode}\")\n",
"\n",
"\n",
"def get_cond_options(mode):\n",
" path = \"data/example_conditioning\"\n",
" path = os.path.join(path, mode)\n",
" onlyfiles = [f for f in sorted(os.listdir(path))]\n",
" return path, onlyfiles\n",
"\n",
"\n",
"def select_cond_path(mode):\n",
" path = \"data/example_conditioning\" # todo\n",
" path = os.path.join(path, mode)\n",
" onlyfiles = [f for f in sorted(os.listdir(path))]\n",
"\n",
" selected = widgets.RadioButtons(\n",
" options=onlyfiles,\n",
" description='Select conditioning:',\n",
" disabled=False\n",
" )\n",
" display.display(selected)\n",
" selected_path = os.path.join(path, selected.value)\n",
" return selected_path\n",
"\n",
"\n",
"def get_cond(mode, img):\n",
" example = dict()\n",
" if mode == \"superresolution\":\n",
" up_f = 4\n",
" # visualize_cond_img(selected_path)\n",
"\n",
" c = img\n",
" c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n",
" c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)\n",
" c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n",
" c = rearrange(c, '1 c h w -> 1 h w c')\n",
" c = 2. * c - 1.\n",
"\n",
" c = c.to(torch.device(\"cuda\"))\n",
" example[\"LR_image\"] = c\n",
" example[\"image\"] = c_up\n",
"\n",
" return example\n",
"\n",
"\n",
"def visualize_cond_img(path):\n",
" display.display(ipyimg(filename=path))\n",
"\n",
"\n",
"def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):\n",
" # global stride\n",
"\n",
" example = get_cond(task, img)\n",
"\n",
" save_intermediate_vid = False\n",
" n_runs = 1\n",
" masked = False\n",
" guider = None\n",
" ckwargs = None\n",
" mode = 'ddim'\n",
" ddim_use_x0_pred = False\n",
" temperature = 1.\n",
" eta = eta\n",
" make_progrow = True\n",
" custom_shape = None\n",
"\n",
" height, width = example[\"image\"].shape[1:3]\n",
" split_input = height >= 128 and width >= 128\n",
"\n",
" if split_input:\n",
" ks = 128\n",
" stride = 64\n",
" vqf = 4 #\n",
" model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n",
" \"vqf\": vqf,\n",
" \"patch_distributed_vq\": True,\n",
" \"tie_braker\": False,\n",
" \"clip_max_weight\": 0.5,\n",
" \"clip_min_weight\": 0.01,\n",
" \"clip_max_tie_weight\": 0.5,\n",
" \"clip_min_tie_weight\": 0.01}\n",
" else:\n",
" if hasattr(model, \"split_input_params\"):\n",
" delattr(model, \"split_input_params\")\n",
"\n",
" invert_mask = False\n",
"\n",
" x_T = None\n",
" for n in range(n_runs):\n",
" if custom_shape is not None:\n",
" x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n",
" x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])\n",
"\n",
" logs = make_convolutional_sample(example, model,\n",
" mode=mode, custom_steps=custom_steps,\n",
" eta=eta, swap_mode=False , masked=masked,\n",
" invert_mask=invert_mask, quantize_x0=False,\n",
" custom_schedule=None, decode_interval=10,\n",
" resize_enabled=resize_enabled, custom_shape=custom_shape,\n",
" temperature=temperature, noise_dropout=0.,\n",
" corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,\n",
" make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred\n",
" )\n",
" return logs\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n",
" mask=None, x0=None, quantize_x0=False, img_callback=None,\n",
" temperature=1., noise_dropout=0., score_corrector=None,\n",
" corrector_kwargs=None, x_T=None, log_every_t=None\n",
" ):\n",
"\n",
" ddim = DDIMSampler(model)\n",
" bs = shape[0] # dont know where this comes from but wayne\n",
" shape = shape[1:] # cut batch dim\n",
" # print(f\"Sampling with eta = {eta}; steps: {steps}\")\n",
" samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n",
" normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n",
" mask=mask, x0=x0, temperature=temperature, verbose=False,\n",
" score_corrector=score_corrector,\n",
" corrector_kwargs=corrector_kwargs, x_T=x_T)\n",
"\n",
" return samples, intermediates\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def make_convolutional_sample(batch, model, mode=\"vanilla\", custom_steps=None, eta=1.0, swap_mode=False, masked=False,\n",
" invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,\n",
" resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,\n",
" corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):\n",
" log = dict()\n",
"\n",
" z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n",
" return_first_stage_outputs=True,\n",
" force_c_encode=not (hasattr(model, 'split_input_params')\n",
" and model.cond_stage_key == 'coordinates_bbox'),\n",
" return_original_cond=True)\n",
"\n",
" log_every_t = 1 if save_intermediate_vid else None\n",
"\n",
" if custom_shape is not None:\n",
" z = torch.randn(custom_shape)\n",
" # print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n",
"\n",
" z0 = None\n",
"\n",
" log[\"input\"] = x\n",
" log[\"reconstruction\"] = xrec\n",
"\n",
" if ismap(xc):\n",
" log[\"original_conditioning\"] = model.to_rgb(xc)\n",
" if hasattr(model, 'cond_stage_key'):\n",
" log[model.cond_stage_key] = model.to_rgb(xc)\n",
"\n",
" else:\n",
" log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n",
" if model.cond_stage_model:\n",
" log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n",
" if model.cond_stage_key =='class_label':\n",
" log[model.cond_stage_key] = xc[model.cond_stage_key]\n",
"\n",
" with model.ema_scope(\"Plotting\"):\n",
" t0 = time.time()\n",
" img_cb = None\n",
"\n",
" sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n",
" eta=eta,\n",
" quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,\n",
" temperature=temperature, noise_dropout=noise_dropout,\n",
" score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n",
" x_T=x_T, log_every_t=log_every_t)\n",
" t1 = time.time()\n",
"\n",
" if ddim_use_x0_pred:\n",
" sample = intermediates['pred_x0'][-1]\n",
"\n",
" x_sample = model.decode_first_stage(sample)\n",
"\n",
" try:\n",
" x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n",
" log[\"sample_noquant\"] = x_sample_noquant\n",
" log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n",
" except:\n",
" pass\n",
"\n",
" log[\"sample\"] = x_sample\n",
" log[\"time\"] = t1 - t0\n",
"\n",
" return log\n",
"\n",
"sr_diffMode = 'superresolution'\n",
"sr_model = get_model('superresolution')\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"def do_superres(img, filepath):\n",
"\n",
" if args.sharpen_preset == 'Faster':\n",
" sr_diffusion_steps = \"25\" \n",
" sr_pre_downsample = '1/2' \n",
" if args.sharpen_preset == 'Fast':\n",
" sr_diffusion_steps = \"100\" \n",
" sr_pre_downsample = '1/2' \n",
" if args.sharpen_preset == 'Slow':\n",
" sr_diffusion_steps = \"25\" \n",
" sr_pre_downsample = 'None' \n",
" if args.sharpen_preset == 'Very Slow':\n",
" sr_diffusion_steps = \"100\" \n",
" sr_pre_downsample = 'None' \n",
"\n",
"\n",
" sr_post_downsample = 'Original Size'\n",
" sr_diffusion_steps = int(sr_diffusion_steps)\n",
" sr_eta = 1.0 \n",
" sr_downsample_method = 'Lanczos' \n",
"\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
"\n",
" im_og = img\n",
" width_og, height_og = im_og.size\n",
"\n",
" #Downsample Pre\n",
" if sr_pre_downsample == '1/2':\n",
" downsample_rate = 2\n",
" elif sr_pre_downsample == '1/4':\n",
" downsample_rate = 4\n",
" else:\n",
" downsample_rate = 1\n",
"\n",
" width_downsampled_pre = width_og//downsample_rate\n",
" height_downsampled_pre = height_og//downsample_rate\n",
"\n",
" if downsample_rate != 1:\n",
" # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n",
" im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n",
" # im_og.save('/content/temp.png')\n",
" # filepath = '/content/temp.png'\n",
"\n",
" logs = sr_run(sr_model[\"model\"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)\n",
"\n",
" sample = logs[\"sample\"]\n",
" sample = sample.detach().cpu()\n",
" sample = torch.clamp(sample, -1., 1.)\n",
" sample = (sample + 1.) / 2. * 255\n",
" sample = sample.numpy().astype(np.uint8)\n",
" sample = np.transpose(sample, (0, 2, 3, 1))\n",
" a = Image.fromarray(sample[0])\n",
"\n",
" #Downsample Post\n",
" if sr_post_downsample == '1/2':\n",
" downsample_rate = 2\n",
" elif sr_post_downsample == '1/4':\n",
" downsample_rate = 4\n",
" else:\n",
" downsample_rate = 1\n",
"\n",
" width, height = a.size\n",
" width_downsampled_post = width//downsample_rate\n",
" height_downsampled_post = height//downsample_rate\n",
"\n",
" if sr_downsample_method == 'Lanczos':\n",
" aliasing = Image.LANCZOS\n",
" else:\n",
" aliasing = Image.NEAREST\n",
"\n",
" if downsample_rate != 1:\n",
" # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')\n",
" a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)\n",
" elif sr_post_downsample == 'Original Size':\n",
" # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')\n",
" a = a.resize((width_og, height_og), aliasing)\n",
"\n",
" display.display(a)\n",
" a.save(filepath)\n",
" return\n",
" print(f'Processing finished!')\n"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
@@ -2248,8 +1705,6 @@
"RN50x4 = False #@param{type:\"boolean\"}\n",
"RN50x16 = False #@param{type:\"boolean\"}\n",
"RN50x64 = False #@param{type:\"boolean\"}\n",
"SLIPB16 = False #@param{type:\"boolean\"}\n",
"SLIPL16 = False #@param{type:\"boolean\"}\n",
"\n",
"#@markdown If you're having issues with model downloads, check this to compare SHA's:\n",
"check_model_SHA = False #@param{type:\"boolean\"}\n",
@@ -2382,36 +1837,8 @@
"if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
"if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
"\n",
"if SLIPB16:\n",
" SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
" if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):\n",
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt\", model_path)\n",
" sd = torch.load(f'{model_path}/slip_base_100ep.pt')\n",
" real_sd = {}\n",
" for k, v in sd['state_dict'].items():\n",
" real_sd['.'.join(k.split('.')[1:])] = v\n",
" del sd\n",
" SLIPB16model.load_state_dict(real_sd)\n",
" SLIPB16model.requires_grad_(False).eval().to(device)\n",
"\n",
" clip_models.append(SLIPB16model)\n",
"\n",
"if SLIPL16:\n",
" SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
" if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n",
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", model_path)\n",
" sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n",
" real_sd = {}\n",
" for k, v in sd['state_dict'].items():\n",
" real_sd['.'.join(k.split('.')[1:])] = v\n",
" del sd\n",
" SLIPL16model.load_state_dict(real_sd)\n",
" SLIPL16model.requires_grad_(False).eval().to(device)\n",
"\n",
" clip_models.append(SLIPL16model)\n",
"\n",
"normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n",
"lpips_model = lpips.LPIPS(net='vgg').to(device)\n"
"lpips_model = lpips.LPIPS(net='vgg').to(device)"
],
"outputs": [],
"execution_count": null
@@ -2466,7 +1893,7 @@
"\n",
"#Make folder for batch\n",
"batchFolder = f'{outDirPath}/{batch_name}'\n",
"createPath(batchFolder)\n"
"createPath(batchFolder)"
],
"outputs": [],
"execution_count": null
@@ -2812,7 +2239,7 @@
" translation_z = float(translation_z)\n",
" rotation_3d_x = float(rotation_3d_x)\n",
" rotation_3d_y = float(rotation_3d_y)\n",
" rotation_3d_z = float(rotation_3d_z)\n"
" rotation_3d_z = float(rotation_3d_z)"
],
"outputs": [],
"execution_count": null
@@ -2824,7 +2251,7 @@
},
"source": [
"### Extra Settings\n",
" Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling"
" Partial Saves, Advanced Settings, Cutn Scheduling"
]
},
{
@@ -2860,18 +2287,6 @@
"\n",
" #@markdown ---\n",
"\n",
"#@markdown ####**SuperRes Sharpening:**\n",
"#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*\n",
"sharpen_preset = 'Off' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']\n",
"keep_unsharp = True #@param{type: 'boolean'}\n",
"\n",
"if sharpen_preset != 'Off' and keep_unsharp is True:\n",
" unsharpenFolder = f'{batchFolder}/unsharpened'\n",
" createPath(unsharpenFolder)\n",
"\n",
"\n",
" #@markdown ---\n",
"\n",
"#@markdown ####**Advanced Settings:**\n",
"#@markdown *There are a few extra advanced settings available if you double click this cell.*\n",
"\n",
@@ -2902,7 +2317,7 @@
"cut_overview = \"[12]*400+[4]*600\" #@param {type: 'string'} \n",
"cut_innercut =\"[4]*400+[12]*600\"#@param {type: 'string'} \n",
"cut_ic_pow = 1#@param {type: 'number'} \n",
"cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}\n"
"cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}"
],
"outputs": [],
"execution_count": null
@@ -2930,7 +2345,7 @@
"\n",
"image_prompts = {\n",
" # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n",
"}\n"
"}"
],
"outputs": [],
"execution_count": null
@@ -3046,8 +2461,6 @@
" 'init_image': init_image,\n",
" 'init_scale': init_scale,\n",
" 'skip_steps': skip_steps,\n",
" 'sharpen_preset': sharpen_preset,\n",
" 'keep_unsharp': keep_unsharp,\n",
" 'side_x': side_x,\n",
" 'side_y': side_y,\n",
" 'timestep_respacing': timestep_respacing,\n",
@@ -3130,7 +2543,7 @@
"finally:\n",
" print('Seed used:', seed)\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n"
" torch.cuda.empty_cache()"
],
"outputs": [],
"execution_count": null
@@ -3224,7 +2637,7 @@
" # mp4 = open(filepath,'rb').read()\n",
" # data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
" # display.HTML(f'<video width=400 controls><source src=\"{data_url}\" type=\"video/mp4\"></video>')\n",
" \n"
" "
],
"outputs": [],
"execution_count": null

@@ -45,6 +45,24 @@ A frankensteinian amalgamation of notebooks, models and techniques for the gener
#### v5 Update: Feb 20th 2022 - gandamu / Adam Letts
* Added 3D animation mode. Uses weighted combination of AdaBins and MiDaS depth estimation models. Uses pytorch3d for 3D transforms on Colab and/or Linux.
#### v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts
* Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.
* Implemented resume of turbo animations in such a way that it's now possible to resume from different batch folders and batch numbers.
* 3D rotation parameter units are now degrees (rather than radians)
* Corrected name collision in sampling_mode (now diffusion_sampling_mode for plms/ddim, and sampling_mode for 3D transform sampling)
* Added video_init_seed_continuity option to make init video animations more continuous
* Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion
* Remove Super Resolution
* Remove Slip Models
* Update for crossplatform support
#### v5.1 Update: Apr 4th 2022 - MSFTserver aka HostsServer
* Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion
* Remove Super Resolution
* Remove Slip Models
* Update for crossplatform support
## Notebook Provenance

File diff suppressed because it is too large (disco.py).