remove super resolution

fix image saving, remove SLIP duplicate
remove super resolution
add metadata to ipynb
pull/25/head
MSFTserver 3 years ago
parent 018137e0e6
commit a1f25c79ec
  1. Disco_Diffusion.ipynb (633 changes)
  2. README.md (10 changes)
  3. disco.py (698 changes)

@ -241,7 +241,7 @@
"\n",
" IPython magic commands replaced by Python code\n",
"\n",
" v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts\n",
" v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts / MSFTserver aka HostsServer\n",
"\n",
" Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.\n",
"\n",
@ -253,8 +253,12 @@
"\n",
" Added video_init_seed_continuity option to make init video animations more continuous\n",
"\n",
" Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion\n",
"\n",
" Remove Super Resolution\n",
"\n",
" '''\n",
" )\n"
" )"
],
"outputs": [],
"execution_count": null
@ -528,32 +532,7 @@
"import random\n",
"from ipywidgets import Output\n",
"import hashlib\n",
"\n",
"#SuperRes\n",
"if is_colab:\n",
" gitclone(\"https://github.com/CompVis/latent-diffusion.git\")\n",
" gitclone(\"https://github.com/CompVis/taming-transformers\")\n",
" pipie(\"./taming-transformers\")\n",
" pipi(\"ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\")\n",
"\n",
"#SuperRes\n",
"import ipywidgets as widgets\n",
"import os\n",
"sys.path.append(\".\")\n",
"sys.path.append('./taming-transformers')\n",
"from taming.models import vqgan # checking correct import from taming\n",
"from torchvision.datasets.utils import download_url\n",
"\n",
"if is_colab:\n",
" os.chdir('/content/latent-diffusion')\n",
"else:\n",
" #os.chdir('latent-diffusion')\n",
" sys.path.append('latent-diffusion')\n",
"from functools import partial\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\n",
"# from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.util import ismap\n",
"if is_colab:\n",
" os.chdir('/content')\n",
" from google.colab import files\n",
@ -1356,7 +1335,6 @@
" \n",
" # with run_display:\n",
" # display.clear_output(wait=True)\n",
" imgToSharpen = None\n",
" for j, sample in enumerate(samples): \n",
" cur_t -= 1\n",
" intermediateStep = False\n",
@ -1405,11 +1383,6 @@
" save_settings()\n",
" if args.animation_mode != \"None\":\n",
" image.save('prevFrame.png')\n",
" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
" imgToSharpen = image\n",
" if args.keep_unsharp is True:\n",
" image.save(f'{unsharpenFolder}/{filename}')\n",
" else:\n",
" image.save(f'{batchFolder}/{filename}')\n",
" if args.animation_mode == \"3D\":\n",
" # If turbo, save a blended image\n",
@ -1425,12 +1398,6 @@
" # if frame_num != args.max_frames-1:\n",
" # display.clear_output()\n",
" \n",
" with image_display: \n",
" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
" print('Starting Diffusion Sharpening...')\n",
" do_superres(imgToSharpen, f'{batchFolder}/{filename}')\n",
" display.clear_output()\n",
" \n",
" plt.plot(np.array(loss_values), 'r')\n",
"\n",
"def save_settings():\n",
@ -1687,539 +1654,6 @@
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "DefSuperRes"
},
"source": [
"#@title 1.7 SuperRes Define\n",
"class DDIMSampler(object):\n",
" def __init__(self, model, schedule=\"linear\", **kwargs):\n",
" super().__init__()\n",
" self.model = model\n",
" self.ddpm_num_timesteps = model.num_timesteps\n",
" self.schedule = schedule\n",
"\n",
" def register_buffer(self, name, attr):\n",
" if type(attr) == torch.Tensor:\n",
" if attr.device != torch.device(\"cuda\"):\n",
" attr = attr.to(torch.device(\"cuda\"))\n",
" setattr(self, name, attr)\n",
"\n",
" def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n",
" self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n",
" num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n",
" alphas_cumprod = self.model.alphas_cumprod\n",
" assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n",
" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n",
"\n",
" self.register_buffer('betas', to_torch(self.model.betas))\n",
" self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n",
" self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n",
"\n",
" # calculations for diffusion q(x_t | x_{t-1}) and others\n",
" self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n",
" self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n",
" self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n",
" self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n",
" self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n",
"\n",
" # ddim sampling parameters\n",
" ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n",
" ddim_timesteps=self.ddim_timesteps,\n",
" eta=ddim_eta,verbose=verbose)\n",
" self.register_buffer('ddim_sigmas', ddim_sigmas)\n",
" self.register_buffer('ddim_alphas', ddim_alphas)\n",
" self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n",
" self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n",
" sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n",
" (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n",
" 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n",
" self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n",
"\n",
" @torch.no_grad()\n",
" def sample(self,\n",
" S,\n",
" batch_size,\n",
" shape,\n",
" conditioning=None,\n",
" callback=None,\n",
" normals_sequence=None,\n",
" img_callback=None,\n",
" quantize_x0=False,\n",
" eta=0.,\n",
" mask=None,\n",
" x0=None,\n",
" temperature=1.,\n",
" noise_dropout=0.,\n",
" score_corrector=None,\n",
" corrector_kwargs=None,\n",
" verbose=True,\n",
" x_T=None,\n",
" log_every_t=100,\n",
" **kwargs\n",
" ):\n",
" if conditioning is not None:\n",
" if isinstance(conditioning, dict):\n",
" cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
" if cbs != batch_size:\n",
" print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
" else:\n",
" if conditioning.shape[0] != batch_size:\n",
" print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n",
"\n",
" self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
" # sampling\n",
" C, H, W = shape\n",
" size = (batch_size, C, H, W)\n",
" # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
"\n",
" samples, intermediates = self.ddim_sampling(conditioning, size,\n",
" callback=callback,\n",
" img_callback=img_callback,\n",
" quantize_denoised=quantize_x0,\n",
" mask=mask, x0=x0,\n",
" ddim_use_original_steps=False,\n",
" noise_dropout=noise_dropout,\n",
" temperature=temperature,\n",
" score_corrector=score_corrector,\n",
" corrector_kwargs=corrector_kwargs,\n",
" x_T=x_T,\n",
" log_every_t=log_every_t\n",
" )\n",
" return samples, intermediates\n",
"\n",
" @torch.no_grad()\n",
" def ddim_sampling(self, cond, shape,\n",
" x_T=None, ddim_use_original_steps=False,\n",
" callback=None, timesteps=None, quantize_denoised=False,\n",
" mask=None, x0=None, img_callback=None, log_every_t=100,\n",
" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
" device = self.model.betas.device\n",
" b = shape[0]\n",
" if x_T is None:\n",
" img = torch.randn(shape, device=device)\n",
" else:\n",
" img = x_T\n",
"\n",
" if timesteps is None:\n",
" timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n",
" elif timesteps is not None and not ddim_use_original_steps:\n",
" subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n",
" timesteps = self.ddim_timesteps[:subset_end]\n",
"\n",
" intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
" time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n",
" total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
" print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n",
"\n",
" iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n",
"\n",
" for i, step in enumerate(iterator):\n",
" index = total_steps - i - 1\n",
" ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
"\n",
" if mask is not None:\n",
" assert x0 is not None\n",
" img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n",
" img = img_orig * mask + (1. - mask) * img\n",
"\n",
" outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n",
" quantize_denoised=quantize_denoised, temperature=temperature,\n",
" noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
" corrector_kwargs=corrector_kwargs)\n",
" img, pred_x0 = outs\n",
" if callback: callback(i)\n",
" if img_callback: img_callback(pred_x0, i)\n",
"\n",
" if index % log_every_t == 0 or index == total_steps - 1:\n",
" intermediates['x_inter'].append(img)\n",
" intermediates['pred_x0'].append(pred_x0)\n",
"\n",
" return img, intermediates\n",
"\n",
" @torch.no_grad()\n",
" def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n",
" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
" b, *_, device = *x.shape, x.device\n",
" e_t = self.model.apply_model(x, t, c)\n",
" if score_corrector is not None:\n",
" assert self.model.parameterization == \"eps\"\n",
" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n",
"\n",
" alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n",
" alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n",
" sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n",
" sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n",
" # select parameters corresponding to the currently considered timestep\n",
" a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n",
" a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n",
" sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n",
" sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n",
"\n",
" # current prediction for x_0\n",
" pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n",
" if quantize_denoised:\n",
" pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n",
" # direction pointing to x_t\n",
" dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n",
" noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n",
" if noise_dropout > 0.:\n",
" noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n",
" x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n",
" return x_prev, pred_x0\n",
"\n",
"\n",
"def download_models(mode):\n",
"\n",
" if mode == \"superresolution\":\n",
" # this is the small bsr light model\n",
" url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n",
" url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n",
"\n",
" path_conf = f'{model_path}/superres/project.yaml'\n",
" path_ckpt = f'{model_path}/superres/last.ckpt'\n",
"\n",
" download_url(url_conf, path_conf)\n",
" download_url(url_ckpt, path_ckpt)\n",
"\n",
" path_conf = path_conf + '/?dl=1' # fix it\n",
" path_ckpt = path_ckpt + '/?dl=1' # fix it\n",
" return path_conf, path_ckpt\n",
"\n",
" else:\n",
" raise NotImplementedError\n",
"\n",
"\n",
"def load_model_from_config(config, ckpt):\n",
" print(f\"Loading model from {ckpt}\")\n",
" pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
" global_step = pl_sd[\"global_step\"]\n",
" sd = pl_sd[\"state_dict\"]\n",
" model = instantiate_from_config(config.model)\n",
" m, u = model.load_state_dict(sd, strict=False)\n",
" model.cuda()\n",
" model.eval()\n",
" return {\"model\": model}, global_step\n",
"\n",
"\n",
"def get_model(mode):\n",
" path_conf, path_ckpt = download_models(mode)\n",
" config = OmegaConf.load(path_conf)\n",
" model, step = load_model_from_config(config, path_ckpt)\n",
" return model\n",
"\n",
"\n",
"def get_custom_cond(mode):\n",
" dest = \"data/example_conditioning\"\n",
"\n",
" if mode == \"superresolution\":\n",
" uploaded_img = files.upload()\n",
" filename = next(iter(uploaded_img))\n",
" name, filetype = filename.split(\".\") # todo assumes just one dot in name !\n",
" os.rename(f\"{filename}\", f\"{dest}/{mode}/custom_{name}.{filetype}\")\n",
"\n",
" elif mode == \"text_conditional\":\n",
" w = widgets.Text(value='A cake with cream!', disabled=True)\n",
" display.display(w)\n",
"\n",
" with open(f\"{dest}/{mode}/custom_{w.value[:20]}.txt\", 'w') as f:\n",
" f.write(w.value)\n",
"\n",
" elif mode == \"class_conditional\":\n",
" w = widgets.IntSlider(min=0, max=1000)\n",
" display.display(w)\n",
" with open(f\"{dest}/{mode}/custom.txt\", 'w') as f:\n",
" f.write(w.value)\n",
"\n",
" else:\n",
" raise NotImplementedError(f\"cond not implemented for mode{mode}\")\n",
"\n",
"\n",
"def get_cond_options(mode):\n",
" path = \"data/example_conditioning\"\n",
" path = os.path.join(path, mode)\n",
" onlyfiles = [f for f in sorted(os.listdir(path))]\n",
" return path, onlyfiles\n",
"\n",
"\n",
"def select_cond_path(mode):\n",
" path = \"data/example_conditioning\" # todo\n",
" path = os.path.join(path, mode)\n",
" onlyfiles = [f for f in sorted(os.listdir(path))]\n",
"\n",
" selected = widgets.RadioButtons(\n",
" options=onlyfiles,\n",
" description='Select conditioning:',\n",
" disabled=False\n",
" )\n",
" display.display(selected)\n",
" selected_path = os.path.join(path, selected.value)\n",
" return selected_path\n",
"\n",
"\n",
"def get_cond(mode, img):\n",
" example = dict()\n",
" if mode == \"superresolution\":\n",
" up_f = 4\n",
" # visualize_cond_img(selected_path)\n",
"\n",
" c = img\n",
" c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n",
" c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)\n",
" c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n",
" c = rearrange(c, '1 c h w -> 1 h w c')\n",
" c = 2. * c - 1.\n",
"\n",
" c = c.to(torch.device(\"cuda\"))\n",
" example[\"LR_image\"] = c\n",
" example[\"image\"] = c_up\n",
"\n",
" return example\n",
"\n",
"\n",
"def visualize_cond_img(path):\n",
" display.display(ipyimg(filename=path))\n",
"\n",
"\n",
"def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):\n",
" # global stride\n",
"\n",
" example = get_cond(task, img)\n",
"\n",
" save_intermediate_vid = False\n",
" n_runs = 1\n",
" masked = False\n",
" guider = None\n",
" ckwargs = None\n",
" mode = 'ddim'\n",
" ddim_use_x0_pred = False\n",
" temperature = 1.\n",
" eta = eta\n",
" make_progrow = True\n",
" custom_shape = None\n",
"\n",
" height, width = example[\"image\"].shape[1:3]\n",
" split_input = height >= 128 and width >= 128\n",
"\n",
" if split_input:\n",
" ks = 128\n",
" stride = 64\n",
" vqf = 4 #\n",
" model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n",
" \"vqf\": vqf,\n",
" \"patch_distributed_vq\": True,\n",
" \"tie_braker\": False,\n",
" \"clip_max_weight\": 0.5,\n",
" \"clip_min_weight\": 0.01,\n",
" \"clip_max_tie_weight\": 0.5,\n",
" \"clip_min_tie_weight\": 0.01}\n",
" else:\n",
" if hasattr(model, \"split_input_params\"):\n",
" delattr(model, \"split_input_params\")\n",
"\n",
" invert_mask = False\n",
"\n",
" x_T = None\n",
" for n in range(n_runs):\n",
" if custom_shape is not None:\n",
" x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n",
" x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])\n",
"\n",
" logs = make_convolutional_sample(example, model,\n",
" mode=mode, custom_steps=custom_steps,\n",
" eta=eta, swap_mode=False , masked=masked,\n",
" invert_mask=invert_mask, quantize_x0=False,\n",
" custom_schedule=None, decode_interval=10,\n",
" resize_enabled=resize_enabled, custom_shape=custom_shape,\n",
" temperature=temperature, noise_dropout=0.,\n",
" corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,\n",
" make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred\n",
" )\n",
" return logs\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n",
" mask=None, x0=None, quantize_x0=False, img_callback=None,\n",
" temperature=1., noise_dropout=0., score_corrector=None,\n",
" corrector_kwargs=None, x_T=None, log_every_t=None\n",
" ):\n",
"\n",
" ddim = DDIMSampler(model)\n",
" bs = shape[0] # dont know where this comes from but wayne\n",
" shape = shape[1:] # cut batch dim\n",
" # print(f\"Sampling with eta = {eta}; steps: {steps}\")\n",
" samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n",
" normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n",
" mask=mask, x0=x0, temperature=temperature, verbose=False,\n",
" score_corrector=score_corrector,\n",
" corrector_kwargs=corrector_kwargs, x_T=x_T)\n",
"\n",
" return samples, intermediates\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def make_convolutional_sample(batch, model, mode=\"vanilla\", custom_steps=None, eta=1.0, swap_mode=False, masked=False,\n",
" invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,\n",
" resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,\n",
" corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):\n",
" log = dict()\n",
"\n",
" z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n",
" return_first_stage_outputs=True,\n",
" force_c_encode=not (hasattr(model, 'split_input_params')\n",
" and model.cond_stage_key == 'coordinates_bbox'),\n",
" return_original_cond=True)\n",
"\n",
" log_every_t = 1 if save_intermediate_vid else None\n",
"\n",
" if custom_shape is not None:\n",
" z = torch.randn(custom_shape)\n",
" # print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n",
"\n",
" z0 = None\n",
"\n",
" log[\"input\"] = x\n",
" log[\"reconstruction\"] = xrec\n",
"\n",
" if ismap(xc):\n",
" log[\"original_conditioning\"] = model.to_rgb(xc)\n",
" if hasattr(model, 'cond_stage_key'):\n",
" log[model.cond_stage_key] = model.to_rgb(xc)\n",
"\n",
" else:\n",
" log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n",
" if model.cond_stage_model:\n",
" log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n",
" if model.cond_stage_key =='class_label':\n",
" log[model.cond_stage_key] = xc[model.cond_stage_key]\n",
"\n",
" with model.ema_scope(\"Plotting\"):\n",
" t0 = time.time()\n",
" img_cb = None\n",
"\n",
" sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n",
" eta=eta,\n",
" quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,\n",
" temperature=temperature, noise_dropout=noise_dropout,\n",
" score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n",
" x_T=x_T, log_every_t=log_every_t)\n",
" t1 = time.time()\n",
"\n",
" if ddim_use_x0_pred:\n",
" sample = intermediates['pred_x0'][-1]\n",
"\n",
" x_sample = model.decode_first_stage(sample)\n",
"\n",
" try:\n",
" x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n",
" log[\"sample_noquant\"] = x_sample_noquant\n",
" log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n",
" except:\n",
" pass\n",
"\n",
" log[\"sample\"] = x_sample\n",
" log[\"time\"] = t1 - t0\n",
"\n",
" return log\n",
"\n",
"sr_diffMode = 'superresolution'\n",
"sr_model = get_model('superresolution')\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"def do_superres(img, filepath):\n",
"\n",
" if args.sharpen_preset == 'Faster':\n",
" sr_diffusion_steps = \"25\" \n",
" sr_pre_downsample = '1/2' \n",
" if args.sharpen_preset == 'Fast':\n",
" sr_diffusion_steps = \"100\" \n",
" sr_pre_downsample = '1/2' \n",
" if args.sharpen_preset == 'Slow':\n",
" sr_diffusion_steps = \"25\" \n",
" sr_pre_downsample = 'None' \n",
" if args.sharpen_preset == 'Very Slow':\n",
" sr_diffusion_steps = \"100\" \n",
" sr_pre_downsample = 'None' \n",
"\n",
"\n",
" sr_post_downsample = 'Original Size'\n",
" sr_diffusion_steps = int(sr_diffusion_steps)\n",
" sr_eta = 1.0 \n",
" sr_downsample_method = 'Lanczos' \n",
"\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
"\n",
" im_og = img\n",
" width_og, height_og = im_og.size\n",
"\n",
" #Downsample Pre\n",
" if sr_pre_downsample == '1/2':\n",
" downsample_rate = 2\n",
" elif sr_pre_downsample == '1/4':\n",
" downsample_rate = 4\n",
" else:\n",
" downsample_rate = 1\n",
"\n",
" width_downsampled_pre = width_og//downsample_rate\n",
" height_downsampled_pre = height_og//downsample_rate\n",
"\n",
" if downsample_rate != 1:\n",
" # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n",
" im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n",
" # im_og.save('/content/temp.png')\n",
" # filepath = '/content/temp.png'\n",
"\n",
" logs = sr_run(sr_model[\"model\"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)\n",
"\n",
" sample = logs[\"sample\"]\n",
" sample = sample.detach().cpu()\n",
" sample = torch.clamp(sample, -1., 1.)\n",
" sample = (sample + 1.) / 2. * 255\n",
" sample = sample.numpy().astype(np.uint8)\n",
" sample = np.transpose(sample, (0, 2, 3, 1))\n",
" a = Image.fromarray(sample[0])\n",
"\n",
" #Downsample Post\n",
" if sr_post_downsample == '1/2':\n",
" downsample_rate = 2\n",
" elif sr_post_downsample == '1/4':\n",
" downsample_rate = 4\n",
" else:\n",
" downsample_rate = 1\n",
"\n",
" width, height = a.size\n",
" width_downsampled_post = width//downsample_rate\n",
" height_downsampled_post = height//downsample_rate\n",
"\n",
" if sr_downsample_method == 'Lanczos':\n",
" aliasing = Image.LANCZOS\n",
" else:\n",
" aliasing = Image.NEAREST\n",
"\n",
" if downsample_rate != 1:\n",
" # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')\n",
" a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)\n",
" elif sr_post_downsample == 'Original Size':\n",
" # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')\n",
" a = a.resize((width_og, height_og), aliasing)\n",
"\n",
" display.display(a)\n",
" a.save(filepath)\n",
" return\n",
" print(f'Processing finished!')\n"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
@ -2240,8 +1674,7 @@
"use_secondary_model = True #@param {type: 'boolean'}\n",
"diffusion_sampling_mode = 'ddim' #@param ['plms','ddim'] \n",
"\n",
"timestep_respacing = '250' #@param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] \n",
"diffusion_steps = 1000 #@param {type: 'number'}\n",
"\n",
"use_checkpoint = True #@param {type: 'boolean'}\n",
"ViTB32 = True #@param{type:\"boolean\"}\n",
"ViTB16 = True #@param{type:\"boolean\"}\n",
@ -2333,9 +1766,9 @@
" model_config.update({\n",
" 'attention_resolutions': '32, 16, 8',\n",
" 'class_cond': False,\n",
" 'diffusion_steps': diffusion_steps,\n",
" 'diffusion_steps': 1000, #No need to edit this, it is taken care of later.\n",
" 'rescale_timesteps': True,\n",
" 'timestep_respacing': timestep_respacing,\n",
" 'timestep_respacing': 250, #No need to edit this, it is taken care of later.\n",
" 'image_size': 512,\n",
" 'learn_sigma': True,\n",
" 'noise_schedule': 'linear',\n",
@ -2351,9 +1784,9 @@
" model_config.update({\n",
" 'attention_resolutions': '32, 16, 8',\n",
" 'class_cond': False,\n",
" 'diffusion_steps': diffusion_steps,\n",
" 'diffusion_steps': 1000, #No need to edit this, it is taken care of later.\n",
" 'rescale_timesteps': True,\n",
" 'timestep_respacing': timestep_respacing,\n",
" 'timestep_respacing': 250, #No need to edit this, it is taken care of later.\n",
" 'image_size': 256,\n",
" 'learn_sigma': True,\n",
" 'noise_schedule': 'linear',\n",
@ -2398,24 +1831,10 @@
" SLIPB16model.load_state_dict(real_sd)\n",
" SLIPB16model.requires_grad_(False).eval().to(device)\n",
"\n",
" clip_models.append(SLIPB16model)\n",
"\n",
"if SLIPL16:\n",
" SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
" if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n",
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", model_path)\n",
" sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n",
" real_sd = {}\n",
" for k, v in sd['state_dict'].items():\n",
" real_sd['.'.join(k.split('.')[1:])] = v\n",
" del sd\n",
" SLIPL16model.load_state_dict(real_sd)\n",
" SLIPL16model.requires_grad_(False).eval().to(device)\n",
"\n",
" clip_models.append(SLIPL16model)\n",
"\n",
"normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n",
"lpips_model = lpips.LPIPS(net='vgg').to(device)\n"
"lpips_model = lpips.LPIPS(net='vgg').to(device)"
],
"outputs": [],
"execution_count": null
@ -2470,7 +1889,7 @@
"\n",
"#Make folder for batch\n",
"batchFolder = f'{outDirPath}/{batch_name}'\n",
"createPath(batchFolder)\n"
"createPath(batchFolder)"
],
"outputs": [],
"execution_count": null
@ -2816,7 +2235,7 @@
" translation_z = float(translation_z)\n",
" rotation_3d_x = float(rotation_3d_x)\n",
" rotation_3d_y = float(rotation_3d_y)\n",
" rotation_3d_z = float(rotation_3d_z)\n"
" rotation_3d_z = float(rotation_3d_z)"
],
"outputs": [],
"execution_count": null
@ -2828,7 +2247,7 @@
},
"source": [
"### Extra Settings\n",
" Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling"
" Partial Saves, Advanced Settings, Cutn Scheduling"
]
},
{
@ -2864,18 +2283,6 @@
"\n",
" #@markdown ---\n",
"\n",
"#@markdown ####**SuperRes Sharpening:**\n",
"#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*\n",
"sharpen_preset = 'Off' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']\n",
"keep_unsharp = True #@param{type: 'boolean'}\n",
"\n",
"if sharpen_preset != 'Off' and keep_unsharp is True:\n",
" unsharpenFolder = f'{batchFolder}/unsharpened'\n",
" createPath(unsharpenFolder)\n",
"\n",
"\n",
" #@markdown ---\n",
"\n",
"#@markdown ####**Advanced Settings:**\n",
"#@markdown *There are a few extra advanced settings available if you double click this cell.*\n",
"\n",
@ -2906,7 +2313,7 @@
"cut_overview = \"[12]*400+[4]*600\" #@param {type: 'string'} \n",
"cut_innercut =\"[4]*400+[12]*600\"#@param {type: 'string'} \n",
"cut_ic_pow = 1#@param {type: 'number'} \n",
"cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}\n"
"cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}"
],
"outputs": [],
"execution_count": null
@ -2934,7 +2341,7 @@
"\n",
"image_prompts = {\n",
" # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n",
"}\n"
"}"
],
"outputs": [],
"execution_count": null
@ -3050,8 +2457,6 @@
" 'init_image': init_image,\n",
" 'init_scale': init_scale,\n",
" 'skip_steps': skip_steps,\n",
" 'sharpen_preset': sharpen_preset,\n",
" 'keep_unsharp': keep_unsharp,\n",
" 'side_x': side_x,\n",
" 'side_y': side_y,\n",
" 'timestep_respacing': timestep_respacing,\n",
@ -3134,7 +2539,7 @@
"finally:\n",
" print('Seed used:', seed)\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n"
" torch.cuda.empty_cache()"
],
"outputs": [],
"execution_count": null
@ -3228,7 +2633,7 @@
" # mp4 = open(filepath,'rb').read()\n",
" # data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
" # display.HTML(f'<video width=400 controls><source src=\"{data_url}\" type=\"video/mp4\"></video>')\n",
" \n"
" "
],
"outputs": [],
"execution_count": null

@ -45,6 +45,16 @@ A frankensteinian amalgamation of notebooks, models and techniques for the gener
#### v5 Update: Feb 20th 2022 - gandamu / Adam Letts
* Added 3D animation mode. Uses weighted combination of AdaBins and MiDaS depth estimation models. Uses pytorch3d for 3D transforms on Colab and/or Linux.
#### v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts / MSFTserver aka HostsServer
* Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.
* Implemented resume of turbo animations in such a way that it's now possible to resume from different batch folders and batch numbers.
* 3D rotation parameter units are now degrees (rather than radians)
* Corrected name collision in sampling_mode (now diffusion_sampling_mode for plms/ddim, and sampling_mode for 3D transform sampling)
* Added video_init_seed_continuity option to make init video animations more continuous
* Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion
* Removed Super Resolution
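For orientation, the renamed settings from the bullets above can be told apart as follows (illustrative sketch only; the `sampling_mode` and rotation values are assumptions, not taken from this commit):

    diffusion_sampling_mode = 'ddim'   # diffusion sampler choice: 'plms' or 'ddim'
    sampling_mode = 'bicubic'          # resampling used by the 3D transform (assumed value)

    rotation_3d_x = 0.2                # 3D rotation is now given in degrees per frame, not radians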
## Notebook Provenance

@ -1,9 +1,16 @@
# %%
# !! {"metadata": {
# !! "id": "view-in-github",
# !! "colab_type": "text"
# !! }}
"""
<a href="https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
"""
# %%
# !! {"metadata": {
# !! "id": "TitleTop"
# !! }}
"""
# Disco Diffusion v5.1 - Now with Turbo
@ -13,11 +20,17 @@ For issues, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA) or
"""
# %%
# !! {"metadata": {
# !! "id": "CreditsChTop"
# !! }}
"""
### Credits & Changelog ⬇
"""
# %%
# !! {"metadata": {
# !! "id": "Credits"
# !! }}
"""
#### Credits
@ -45,11 +58,17 @@ Turbo feature by Chris Allen (https://twitter.com/zippy731)
"""
# %%
# !! {"metadata": {
# !! "id": "LicenseTop"
# !! }}
"""
#### License
"""
# %%
# !! {"metadata": {
# !! "id": "License"
# !! }}
"""
Licensed under the MIT License
@ -125,11 +144,18 @@ THE SOFTWARE.
"""
# %%
# !! {"metadata": {
# !! "id": "ChangelogTop"
# !! }}
"""
#### Changelog
"""
# %%
# !! {"metadata": {
# !! "cellView": "form",
# !! "id": "Changelog"
# !! }}
#@title <- View Changelog
skip_for_run_all = True #@param {type: 'boolean'}
@ -204,7 +230,7 @@ if skip_for_run_all == False:
IPython magic commands replaced by Python code
v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts
v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts / MSFTserver aka HostsServer
Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.
@ -216,16 +242,26 @@ if skip_for_run_all == False:
Added video_init_seed_continuity option to make init video animations more continuous
Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion
Removed Super Resolution
'''
)
# %%
# !! {"metadata": {
# !! "id": "TutorialTop"
# !! }}
"""
# Tutorial
"""
# %%
# !! {"metadata": {
# !! "id": "DiffusionSet"
# !! }}
"""
**Diffusion settings (Defaults are heavily outdated)**
---
@ -275,11 +311,18 @@ Setting | Description | Default
"""
# %%
# !! {"metadata": {
# !! "id": "SetupTop"
# !! }}
"""
# 1. Set Up
"""
# %%
# !! {"metadata": {
# !! "cellView": "form",
# !! "id": "CheckGPU"
# !! }}
#@title 1.1 Check GPU Status
import subprocess
simple_nvidia_smi_display = False#@param {type:"boolean"}
@ -295,6 +338,10 @@ else:
print(nvidiasmi_ecc_note)
# %%
# !! {"metadata": {
# !! "cellView": "form",
# !! "id": "PrepFolders"
# !! }}
#@title 1.2 Prepare Folders
import subprocess
import sys
@ -363,6 +410,10 @@ else:
# createPath(libraries)
# %%
# !! {"metadata": {
# !! "cellView": "form",
# !! "id": "InstallDeps"
# !! }}
#@title ### 1.3 Install and import dependencies
import pathlib, shutil
@ -453,32 +504,7 @@ import matplotlib.pyplot as plt
import random
from ipywidgets import Output
import hashlib
#SuperRes
if is_colab:
gitclone("https://github.com/CompVis/latent-diffusion.git")
gitclone("https://github.com/CompVis/taming-transformers")
pipie("./taming-transformers")
pipi("ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb")
#SuperRes
import ipywidgets as widgets
import os
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan # checking correct import from taming
from torchvision.datasets.utils import download_url
if is_colab:
os.chdir('/content/latent-diffusion')
else:
#os.chdir('latent-diffusion')
sys.path.append('latent-diffusion')
from functools import partial
from ldm.util import instantiate_from_config
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
# from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
if is_colab:
os.chdir('/content')
from google.colab import files
@ -515,6 +541,10 @@ if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad
torch.backends.cudnn.enabled = False
# %%
# !! {"metadata": {
# !! "cellView": "form",
# !! "id": "DefMidasFns"
# !! }}
#@title ### 1.4 Define Midas functions
from midas.dpt_depth import DPTDepthModel
@ -618,6 +648,10 @@ def init_midas_depth_model(midas_model_type="dpt_large", optimize=True):
return midas_model, midas_transform, net_w, net_h, resize_mode, normalization
# %%
# !! {"metadata": {
# !! "cellView": "form",
# !! "id": "DefFns"
# !! }}
#@title 1.5 Define necessary functions
# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
@ -1263,7 +1297,6 @@ def do_run():
# with run_display:
# display.clear_output(wait=True)
imgToSharpen = None
for j, sample in enumerate(samples):
cur_t -= 1
intermediateStep = False
@ -1312,11 +1345,6 @@ def do_run():
save_settings()
if args.animation_mode != "None":
image.save('prevFrame.png')
if args.sharpen_preset != "Off" and animation_mode == "None":
imgToSharpen = image
if args.keep_unsharp is True:
image.save(f'{unsharpenFolder}/{filename}')
else:
image.save(f'{batchFolder}/{filename}')
if args.animation_mode == "3D":
# If turbo, save a blended image
@ -1332,12 +1360,6 @@ def do_run():
# if frame_num != args.max_frames-1:
# display.clear_output()
with image_display:
if args.sharpen_preset != "Off" and animation_mode == "None":
print('Starting Diffusion Sharpening...')
do_superres(imgToSharpen, f'{batchFolder}/{filename}')
display.clear_output()
plt.plot(np.array(loss_values), 'r')
def save_settings():
@ -1418,6 +1440,10 @@ def save_settings():
json.dump(setting_list, f, ensure_ascii=False, indent=4)
# %%
# !! {"metadata": {
# !! "cellView": "form",
# !! "id": "DefSecModel"
# !! }}
#@title 1.6 Define the secondary diffusion model
def append_dims(x, n):
@ -1582,537 +1608,19 @@ class SecondaryDiffusionImageNet2(nn.Module):
eps = input * sigmas + v * alphas
return DiffusionOutput(v, pred, eps)
# %%
#@title 1.7 SuperRes Define
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sharpening with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
e_t = self.model.apply_model(x, t, c)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
def download_models(mode):
if mode == "superresolution":
# this is the small bsr light model
url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'
path_conf = f'{model_path}/superres/project.yaml'
path_ckpt = f'{model_path}/superres/last.ckpt'
download_url(url_conf, path_conf)
download_url(url_ckpt, path_ckpt)
path_conf = path_conf + '/?dl=1' # fix it
path_ckpt = path_ckpt + '/?dl=1' # fix it
return path_conf, path_ckpt
else:
raise NotImplementedError
def load_model_from_config(config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return {"model": model}, global_step
def get_model(mode):
path_conf, path_ckpt = download_models(mode)
config = OmegaConf.load(path_conf)
model, step = load_model_from_config(config, path_ckpt)
return model
def get_custom_cond(mode):
dest = "data/example_conditioning"
if mode == "superresolution":
uploaded_img = files.upload()
filename = next(iter(uploaded_img))
name, filetype = filename.split(".") # todo assumes just one dot in name !
os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
elif mode == "text_conditional":
w = widgets.Text(value='A cake with cream!', disabled=True)
display.display(w)
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
f.write(w.value)
elif mode == "class_conditional":
w = widgets.IntSlider(min=0, max=1000)
display.display(w)
with open(f"{dest}/{mode}/custom.txt", 'w') as f:
f.write(w.value)
else:
raise NotImplementedError(f"cond not implemented for mode{mode}")
def get_cond_options(mode):
path = "data/example_conditioning"
path = os.path.join(path, mode)
onlyfiles = [f for f in sorted(os.listdir(path))]
return path, onlyfiles
def select_cond_path(mode):
path = "data/example_conditioning" # todo
path = os.path.join(path, mode)
onlyfiles = [f for f in sorted(os.listdir(path))]
selected = widgets.RadioButtons(
options=onlyfiles,
description='Select conditioning:',
disabled=False
)
display.display(selected)
selected_path = os.path.join(path, selected.value)
return selected_path
def get_cond(mode, img):
example = dict()
if mode == "superresolution":
up_f = 4
# visualize_cond_img(selected_path)
c = img
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(torch.device("cuda"))
example["LR_image"] = c
example["image"] = c_up
return example
def visualize_cond_img(path):
display.display(ipyimg(filename=path))
def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):
# global stride
example = get_cond(task, img)
save_intermediate_vid = False
n_runs = 1
masked = False
guider = None
ckwargs = None
mode = 'ddim'
ddim_use_x0_pred = False
temperature = 1.
eta = eta
make_progrow = True
custom_shape = None
height, width = example["image"].shape[1:3]
split_input = height >= 128 and width >= 128
if split_input:
ks = 128
stride = 64
vqf = 4 #
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
invert_mask = False
x_T = None
for n in range(n_runs):
if custom_shape is not None:
x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])
logs = make_convolutional_sample(example, model,
mode=mode, custom_steps=custom_steps,
eta=eta, swap_mode=False , masked=masked,
invert_mask=invert_mask, quantize_x0=False,
custom_schedule=None, decode_interval=10,
resize_enabled=resize_enabled, custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,
make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred
)
return logs
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
mask=None, x0=None, quantize_x0=False, img_callback=None,
temperature=1., noise_dropout=0., score_corrector=None,
corrector_kwargs=None, x_T=None, log_every_t=None
):
ddim = DDIMSampler(model)
bs = shape[0] # dont know where this comes from but wayne
shape = shape[1:] # cut batch dim
# print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
mask=mask, x0=x0, temperature=temperature, verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, x_T=x_T)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):
log = dict()
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, 'split_input_params')
and model.cond_stage_key == 'coordinates_bbox'),
return_original_cond=True)
log_every_t = 1 if save_intermediate_vid else None
if custom_shape is not None:
z = torch.randn(custom_shape)
# print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
z0 = None
log["input"] = x
log["reconstruction"] = xrec
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, 'cond_stage_key'):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key =='class_label':
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
img_cb = None
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
eta=eta,
quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,
temperature=temperature, noise_dropout=noise_dropout,
score_corrector=corrector, corrector_kwargs=corrector_kwargs,
x_T=x_T, log_every_t=log_every_t)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates['pred_x0'][-1]
x_sample = model.decode_first_stage(sample)
try:
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except:
pass
log["sample"] = x_sample
log["time"] = t1 - t0
return log
sr_diffMode = 'superresolution'
sr_model = get_model('superresolution')
def do_superres(img, filepath):
if args.sharpen_preset == 'Faster':
sr_diffusion_steps = "25"
sr_pre_downsample = '1/2'
if args.sharpen_preset == 'Fast':
sr_diffusion_steps = "100"
sr_pre_downsample = '1/2'
if args.sharpen_preset == 'Slow':
sr_diffusion_steps = "25"
sr_pre_downsample = 'None'
if args.sharpen_preset == 'Very Slow':
sr_diffusion_steps = "100"
sr_pre_downsample = 'None'
sr_post_downsample = 'Original Size'
sr_diffusion_steps = int(sr_diffusion_steps)
sr_eta = 1.0
sr_downsample_method = 'Lanczos'
gc.collect()
torch.cuda.empty_cache()
im_og = img
width_og, height_og = im_og.size
#Downsample Pre
if sr_pre_downsample == '1/2':
downsample_rate = 2
elif sr_pre_downsample == '1/4':
downsample_rate = 4
else:
downsample_rate = 1
width_downsampled_pre = width_og//downsample_rate
height_downsampled_pre = height_og//downsample_rate
if downsample_rate != 1:
# print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
# im_og.save('/content/temp.png')
# filepath = '/content/temp.png'
logs = sr_run(sr_model["model"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)
sample = logs["sample"]
sample = sample.detach().cpu()
sample = torch.clamp(sample, -1., 1.)
sample = (sample + 1.) / 2. * 255
sample = sample.numpy().astype(np.uint8)
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
#Downsample Post
if sr_post_downsample == '1/2':
downsample_rate = 2
elif sr_post_downsample == '1/4':
downsample_rate = 4
else:
downsample_rate = 1
width, height = a.size
width_downsampled_post = width//downsample_rate
height_downsampled_post = height//downsample_rate
if sr_downsample_method == 'Lanczos':
aliasing = Image.LANCZOS
else:
aliasing = Image.NEAREST
if downsample_rate != 1:
# print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')
a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)
elif sr_post_downsample == 'Original Size':
# print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')
a = a.resize((width_og, height_og), aliasing)
display.display(a)
a.save(filepath)
return
print(f'Processing finished!')
# %%
# !! {"metadata": {
# !! "id": "DiffClipSetTop"
# !! }}
"""
# 2. Diffusion and CLIP model settings
"""
# %%
# !! {"metadata": {
# !! "id": "ModelSettings"
# !! }}
#@markdown ####**Models Settings:**
diffusion_model = "512x512_diffusion_uncond_finetune_008100" #@param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"]
use_secondary_model = True #@param {type: 'boolean'}
@ -2275,20 +1783,6 @@ if SLIPB16:
SLIPB16model.load_state_dict(real_sd)
SLIPB16model.requires_grad_(False).eval().to(device)
clip_models.append(SLIPB16model)
if SLIPL16:
SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)
if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):
wget("https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt", model_path)
sd = torch.load(f'{model_path}/slip_large_100ep.pt')
real_sd = {}
for k, v in sd['state_dict'].items():
real_sd['.'.join(k.split('.')[1:])] = v
del sd
SLIPL16model.load_state_dict(real_sd)
SLIPL16model.requires_grad_(False).eval().to(device)
clip_models.append(SLIPL16model)
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
@ -2296,11 +1790,17 @@ lpips_model = lpips.LPIPS(net='vgg').to(device)
# %%
# !! {"metadata": {
# !! "id": "SettingsTop"
# !! }}
"""
# 3. Settings
"""
# %%
# !! {"metadata": {
# !! "id": "BasicSettings"
# !! }}
#@markdown ####**Basic Settings:**
batch_name = 'TimeToDisco' #@param{type: 'string'}
steps = 250 #@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}
@ -2340,11 +1840,17 @@ createPath(batchFolder)
# %%
# !! {"metadata": {
# !! "id": "AnimSetTop"
# !! }}
"""
### Animation Settings
"""
# %%
# !! {"metadata": {
# !! "id": "AnimSettings"
# !! }}
#@markdown ####**Animation Mode:**
animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'}
#@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.*
@ -2675,12 +2181,18 @@ else:
# %%
# !! {"metadata": {
# !! "id": "ExtraSetTop"
# !! }}
"""
### Extra Settings
Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling
Partial Saves, Advanced Settings, Cutn Scheduling
"""
# %%
# !! {"metadata": {
# !! "id": "ExtraSettings"
# !! }}
#@markdown ####**Saving:**
intermediate_saves = 0#@param{type: 'raw'}
@ -2706,18 +2218,6 @@ if intermediate_saves and intermediates_in_subfolder is True:
partialFolder = f'{batchFolder}/partials'
createPath(partialFolder)
#@markdown ---
#@markdown ####**SuperRes Sharpening:**
#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*
sharpen_preset = 'Off' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']
keep_unsharp = True #@param{type: 'boolean'}
if sharpen_preset != 'Off' and keep_unsharp is True:
unsharpenFolder = f'{batchFolder}/unsharpened'
createPath(unsharpenFolder)
#@markdown ---
#@markdown ####**Advanced Settings:**
@ -2754,12 +2254,18 @@ cut_icgray_p = "[0.2]*400+[0]*600"#@param {type: 'string'}
# %%
# !! {"metadata": {
# !! "id": "PromptsTop"
# !! }}
"""
### Prompts
`animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one.
"""
# %%
# !! {"metadata": {
# !! "id": "Prompts"
# !! }}
text_prompts = {
0: ["A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.", "yellow color scheme"],
100: ["This set of prompts start at frame 100","This prompt has weight five:5"],
@ -2771,11 +2277,17 @@ image_prompts = {
# %%
# !! {"metadata": {
# !! "id": "DiffuseTop"
# !! }}
"""
# 4. Diffuse!
"""
# %%
# !! {"metadata": {
# !! "id": "DoTheRun"
# !! }}
#@title Do the Run!
#@markdown `n_batches` ignored with animation modes.
display_rate = 50 #@param{type: 'number'}
@ -2872,8 +2384,6 @@ args = {
'init_image': init_image,
'init_scale': init_scale,
'skip_steps': skip_steps,
'sharpen_preset': sharpen_preset,
'keep_unsharp': keep_unsharp,
'side_x': side_x,
'side_y': side_y,
'timestep_respacing': timestep_respacing,
@ -2960,11 +2470,17 @@ finally:
# %%
# !! {"metadata": {
# !! "id": "CreateVidTop"
# !! }}
"""
# 5. Create the video
"""
# %%
# !! {"metadata": {
# !! "id": "CreateVid"
# !! }}
# @title ### **Create video**
#@markdown Video file will save in the same folder as your images.
