diff --git a/Disco_Diffusion.ipynb b/Disco_Diffusion.ipynb index df48f19..97fea4a 100644 --- a/Disco_Diffusion.ipynb +++ b/Disco_Diffusion.ipynb @@ -241,7 +241,7 @@ "\n", " IPython magic commands replaced by Python code\n", "\n", - " v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts\n", + " v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts / MSFTserver aka HostsServer\n", "\n", " Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.\n", "\n", @@ -253,8 +253,12 @@ "\n", " Added video_init_seed_continuity option to make init video animations more continuous\n", "\n", + " Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion\n", + "\n", + " Remove Super Resolution\n", + "\n", " '''\n", - " )\n" + " )" ], "outputs": [], "execution_count": null @@ -528,32 +532,7 @@ "import random\n", "from ipywidgets import Output\n", "import hashlib\n", - "\n", - "#SuperRes\n", - "if is_colab:\n", - " gitclone(\"https://github.com/CompVis/latent-diffusion.git\")\n", - " gitclone(\"https://github.com/CompVis/taming-transformers\")\n", - " pipie(\"./taming-transformers\")\n", - " pipi(\"ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\")\n", - "\n", - "#SuperRes\n", - "import ipywidgets as widgets\n", - "import os\n", - "sys.path.append(\".\")\n", - "sys.path.append('./taming-transformers')\n", - "from taming.models import vqgan # checking correct import from taming\n", - "from torchvision.datasets.utils import download_url\n", - "\n", - "if is_colab:\n", - " os.chdir('/content/latent-diffusion')\n", - "else:\n", - " #os.chdir('latent-diffusion')\n", - " sys.path.append('latent-diffusion')\n", "from functools import partial\n", - "from ldm.util import instantiate_from_config\n", - "from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\n", - "# from ldm.models.diffusion.ddim import DDIMSampler\n", - "from ldm.util import ismap\n", "if is_colab:\n", " os.chdir('/content')\n", " from google.colab import files\n", @@ -1356,7 +1335,6 @@ " \n", " # with run_display:\n", " # display.clear_output(wait=True)\n", - " imgToSharpen = None\n", " for j, sample in enumerate(samples): \n", " cur_t -= 1\n", " intermediateStep = False\n", @@ -1405,31 +1383,20 @@ " save_settings()\n", " if args.animation_mode != \"None\":\n", " image.save('prevFrame.png')\n", - " if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n", - " imgToSharpen = image\n", - " if args.keep_unsharp is True:\n", - " image.save(f'{unsharpenFolder}/{filename}')\n", - " else:\n", - " image.save(f'{batchFolder}/{filename}')\n", - " if args.animation_mode == \"3D\":\n", - " # If turbo, save a blended image\n", - " if turbo_mode:\n", - " # Mix new image with prevFrameScaled\n", - " blend_factor = (1)/int(turbo_steps)\n", - " newFrame = cv2.imread('prevFrame.png') # This is already updated..\n", - " prev_frame_warped = cv2.imread('prevFrameScaled.png')\n", - " blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n", - " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n", - " else:\n", - " image.save(f'{batchFolder}/{filename}')\n", + " image.save(f'{batchFolder}/{filename}')\n", + " if args.animation_mode == \"3D\":\n", + " # If turbo, save a blended image\n", + " if turbo_mode:\n", + " # Mix new image with prevFrameScaled\n", + " blend_factor = (1)/int(turbo_steps)\n", + " newFrame = cv2.imread('prevFrame.png') # This is already updated..\n", + " prev_frame_warped = cv2.imread('prevFrameScaled.png')\n", + " blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n", + " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n", + " else:\n", + " image.save(f'{batchFolder}/{filename}')\n", " # if frame_num != args.max_frames-1:\n", " # display.clear_output()\n", - "\n", - " with image_display: \n", - " if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n", - " print('Starting Diffusion Sharpening...')\n", - " do_superres(imgToSharpen, f'{batchFolder}/{filename}')\n", - " display.clear_output()\n", " \n", " plt.plot(np.array(loss_values), 'r')\n", "\n", @@ -1687,539 +1654,6 @@ "outputs": [], "execution_count": null }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "DefSuperRes" - }, - "source": [ - "#@title 1.7 SuperRes Define\n", - "class DDIMSampler(object):\n", - " def __init__(self, model, schedule=\"linear\", **kwargs):\n", - " super().__init__()\n", - " self.model = model\n", - " self.ddpm_num_timesteps = model.num_timesteps\n", - " self.schedule = schedule\n", - "\n", - " def register_buffer(self, name, attr):\n", - " if type(attr) == torch.Tensor:\n", - " if attr.device != torch.device(\"cuda\"):\n", - " attr = attr.to(torch.device(\"cuda\"))\n", - " setattr(self, name, attr)\n", - "\n", - " def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n", - " self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n", - " num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n", - " alphas_cumprod = self.model.alphas_cumprod\n", - " assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n", - " to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n", - "\n", - " self.register_buffer('betas', to_torch(self.model.betas))\n", - " self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n", - " self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n", - "\n", - " # calculations for diffusion q(x_t | x_{t-1}) and others\n", - " self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n", - " self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n", - " self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n", - " self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n", - " self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n", - "\n", - " # ddim sampling parameters\n", - " ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n", - " ddim_timesteps=self.ddim_timesteps,\n", - " eta=ddim_eta,verbose=verbose)\n", - " self.register_buffer('ddim_sigmas', ddim_sigmas)\n", - " self.register_buffer('ddim_alphas', ddim_alphas)\n", - " self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n", - " self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n", - " sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n", - " (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n", - " 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n", - " self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n", - "\n", - " @torch.no_grad()\n", - " def sample(self,\n", - " S,\n", - " batch_size,\n", - " shape,\n", - " conditioning=None,\n", - " callback=None,\n", - " normals_sequence=None,\n", - " img_callback=None,\n", - " quantize_x0=False,\n", - " eta=0.,\n", - " mask=None,\n", - " x0=None,\n", - " temperature=1.,\n", - " noise_dropout=0.,\n", - " score_corrector=None,\n", - " corrector_kwargs=None,\n", - " verbose=True,\n", - " x_T=None,\n", - " log_every_t=100,\n", - " **kwargs\n", - " ):\n", - " if conditioning is not None:\n", - " if isinstance(conditioning, dict):\n", - " cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n", - " if cbs != batch_size:\n", - " print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n", - " else:\n", - " if conditioning.shape[0] != batch_size:\n", - " print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n", - "\n", - " self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n", - " # sampling\n", - " C, H, W = shape\n", - " size = (batch_size, C, H, W)\n", - " # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n", - "\n", - " samples, intermediates = self.ddim_sampling(conditioning, size,\n", - " callback=callback,\n", - " img_callback=img_callback,\n", - " quantize_denoised=quantize_x0,\n", - " mask=mask, x0=x0,\n", - " ddim_use_original_steps=False,\n", - " noise_dropout=noise_dropout,\n", - " temperature=temperature,\n", - " score_corrector=score_corrector,\n", - " corrector_kwargs=corrector_kwargs,\n", - " x_T=x_T,\n", - " log_every_t=log_every_t\n", - " )\n", - " return samples, intermediates\n", - "\n", - " @torch.no_grad()\n", - " def ddim_sampling(self, cond, shape,\n", - " x_T=None, ddim_use_original_steps=False,\n", - " callback=None, timesteps=None, quantize_denoised=False,\n", - " mask=None, x0=None, img_callback=None, log_every_t=100,\n", - " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n", - " device = self.model.betas.device\n", - " b = shape[0]\n", - " if x_T is None:\n", - " img = torch.randn(shape, device=device)\n", - " else:\n", - " img = x_T\n", - "\n", - " if timesteps is None:\n", - " timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n", - " elif timesteps is not None and not ddim_use_original_steps:\n", - " subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n", - " timesteps = self.ddim_timesteps[:subset_end]\n", - "\n", - " intermediates = {'x_inter': [img], 'pred_x0': [img]}\n", - " time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n", - " total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n", - " print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n", - "\n", - " iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n", - "\n", - " for i, step in enumerate(iterator):\n", - " index = total_steps - i - 1\n", - " ts = torch.full((b,), step, device=device, dtype=torch.long)\n", - "\n", - " if mask is not None:\n", - " assert x0 is not None\n", - " img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n", - " img = img_orig * mask + (1. - mask) * img\n", - "\n", - " outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n", - " quantize_denoised=quantize_denoised, temperature=temperature,\n", - " noise_dropout=noise_dropout, score_corrector=score_corrector,\n", - " corrector_kwargs=corrector_kwargs)\n", - " img, pred_x0 = outs\n", - " if callback: callback(i)\n", - " if img_callback: img_callback(pred_x0, i)\n", - "\n", - " if index % log_every_t == 0 or index == total_steps - 1:\n", - " intermediates['x_inter'].append(img)\n", - " intermediates['pred_x0'].append(pred_x0)\n", - "\n", - " return img, intermediates\n", - "\n", - " @torch.no_grad()\n", - " def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n", - " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n", - " b, *_, device = *x.shape, x.device\n", - " e_t = self.model.apply_model(x, t, c)\n", - " if score_corrector is not None:\n", - " assert self.model.parameterization == \"eps\"\n", - " e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n", - "\n", - " alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n", - " alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n", - " sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n", - " sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n", - " # select parameters corresponding to the currently considered timestep\n", - " a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n", - " a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n", - " sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n", - " sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n", - "\n", - " # current prediction for x_0\n", - " pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n", - " if quantize_denoised:\n", - " pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n", - " # direction pointing to x_t\n", - " dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n", - " noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n", - " if noise_dropout > 0.:\n", - " noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n", - " x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n", - " return x_prev, pred_x0\n", - "\n", - "\n", - "def download_models(mode):\n", - "\n", - " if mode == \"superresolution\":\n", - " # this is the small bsr light model\n", - " url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n", - " url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n", - "\n", - " path_conf = f'{model_path}/superres/project.yaml'\n", - " path_ckpt = f'{model_path}/superres/last.ckpt'\n", - "\n", - " download_url(url_conf, path_conf)\n", - " download_url(url_ckpt, path_ckpt)\n", - "\n", - " path_conf = path_conf + '/?dl=1' # fix it\n", - " path_ckpt = path_ckpt + '/?dl=1' # fix it\n", - " return path_conf, path_ckpt\n", - "\n", - " else:\n", - " raise NotImplementedError\n", - "\n", - "\n", - "def load_model_from_config(config, ckpt):\n", - " print(f\"Loading model from {ckpt}\")\n", - " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n", - " global_step = pl_sd[\"global_step\"]\n", - " sd = pl_sd[\"state_dict\"]\n", - " model = instantiate_from_config(config.model)\n", - " m, u = model.load_state_dict(sd, strict=False)\n", - " model.cuda()\n", - " model.eval()\n", - " return {\"model\": model}, global_step\n", - "\n", - "\n", - "def get_model(mode):\n", - " path_conf, path_ckpt = download_models(mode)\n", - " config = OmegaConf.load(path_conf)\n", - " model, step = load_model_from_config(config, path_ckpt)\n", - " return model\n", - "\n", - "\n", - "def get_custom_cond(mode):\n", - " dest = \"data/example_conditioning\"\n", - "\n", - " if mode == \"superresolution\":\n", - " uploaded_img = files.upload()\n", - " filename = next(iter(uploaded_img))\n", - " name, filetype = filename.split(\".\") # todo assumes just one dot in name !\n", - " os.rename(f\"{filename}\", f\"{dest}/{mode}/custom_{name}.{filetype}\")\n", - "\n", - " elif mode == \"text_conditional\":\n", - " w = widgets.Text(value='A cake with cream!', disabled=True)\n", - " display.display(w)\n", - "\n", - " with open(f\"{dest}/{mode}/custom_{w.value[:20]}.txt\", 'w') as f:\n", - " f.write(w.value)\n", - "\n", - " elif mode == \"class_conditional\":\n", - " w = widgets.IntSlider(min=0, max=1000)\n", - " display.display(w)\n", - " with open(f\"{dest}/{mode}/custom.txt\", 'w') as f:\n", - " f.write(w.value)\n", - "\n", - " else:\n", - " raise NotImplementedError(f\"cond not implemented for mode{mode}\")\n", - "\n", - "\n", - "def get_cond_options(mode):\n", - " path = \"data/example_conditioning\"\n", - " path = os.path.join(path, mode)\n", - " onlyfiles = [f for f in sorted(os.listdir(path))]\n", - " return path, onlyfiles\n", - "\n", - "\n", - "def select_cond_path(mode):\n", - " path = \"data/example_conditioning\" # todo\n", - " path = os.path.join(path, mode)\n", - " onlyfiles = [f for f in sorted(os.listdir(path))]\n", - "\n", - " selected = widgets.RadioButtons(\n", - " options=onlyfiles,\n", - " description='Select conditioning:',\n", - " disabled=False\n", - " )\n", - " display.display(selected)\n", - " selected_path = os.path.join(path, selected.value)\n", - " return selected_path\n", - "\n", - "\n", - "def get_cond(mode, img):\n", - " example = dict()\n", - " if mode == \"superresolution\":\n", - " up_f = 4\n", - " # visualize_cond_img(selected_path)\n", - "\n", - " c = img\n", - " c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n", - " c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)\n", - " c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n", - " c = rearrange(c, '1 c h w -> 1 h w c')\n", - " c = 2. * c - 1.\n", - "\n", - " c = c.to(torch.device(\"cuda\"))\n", - " example[\"LR_image\"] = c\n", - " example[\"image\"] = c_up\n", - "\n", - " return example\n", - "\n", - "\n", - "def visualize_cond_img(path):\n", - " display.display(ipyimg(filename=path))\n", - "\n", - "\n", - "def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):\n", - " # global stride\n", - "\n", - " example = get_cond(task, img)\n", - "\n", - " save_intermediate_vid = False\n", - " n_runs = 1\n", - " masked = False\n", - " guider = None\n", - " ckwargs = None\n", - " mode = 'ddim'\n", - " ddim_use_x0_pred = False\n", - " temperature = 1.\n", - " eta = eta\n", - " make_progrow = True\n", - " custom_shape = None\n", - "\n", - " height, width = example[\"image\"].shape[1:3]\n", - " split_input = height >= 128 and width >= 128\n", - "\n", - " if split_input:\n", - " ks = 128\n", - " stride = 64\n", - " vqf = 4 #\n", - " model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n", - " \"vqf\": vqf,\n", - " \"patch_distributed_vq\": True,\n", - " \"tie_braker\": False,\n", - " \"clip_max_weight\": 0.5,\n", - " \"clip_min_weight\": 0.01,\n", - " \"clip_max_tie_weight\": 0.5,\n", - " \"clip_min_tie_weight\": 0.01}\n", - " else:\n", - " if hasattr(model, \"split_input_params\"):\n", - " delattr(model, \"split_input_params\")\n", - "\n", - " invert_mask = False\n", - "\n", - " x_T = None\n", - " for n in range(n_runs):\n", - " if custom_shape is not None:\n", - " x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n", - " x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])\n", - "\n", - " logs = make_convolutional_sample(example, model,\n", - " mode=mode, custom_steps=custom_steps,\n", - " eta=eta, swap_mode=False , masked=masked,\n", - " invert_mask=invert_mask, quantize_x0=False,\n", - " custom_schedule=None, decode_interval=10,\n", - " resize_enabled=resize_enabled, custom_shape=custom_shape,\n", - " temperature=temperature, noise_dropout=0.,\n", - " corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,\n", - " make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred\n", - " )\n", - " return logs\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n", - " mask=None, x0=None, quantize_x0=False, img_callback=None,\n", - " temperature=1., noise_dropout=0., score_corrector=None,\n", - " corrector_kwargs=None, x_T=None, log_every_t=None\n", - " ):\n", - "\n", - " ddim = DDIMSampler(model)\n", - " bs = shape[0] # dont know where this comes from but wayne\n", - " shape = shape[1:] # cut batch dim\n", - " # print(f\"Sampling with eta = {eta}; steps: {steps}\")\n", - " samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n", - " normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n", - " mask=mask, x0=x0, temperature=temperature, verbose=False,\n", - " score_corrector=score_corrector,\n", - " corrector_kwargs=corrector_kwargs, x_T=x_T)\n", - "\n", - " return samples, intermediates\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def make_convolutional_sample(batch, model, mode=\"vanilla\", custom_steps=None, eta=1.0, swap_mode=False, masked=False,\n", - " invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,\n", - " resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,\n", - " corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):\n", - " log = dict()\n", - "\n", - " z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n", - " return_first_stage_outputs=True,\n", - " force_c_encode=not (hasattr(model, 'split_input_params')\n", - " and model.cond_stage_key == 'coordinates_bbox'),\n", - " return_original_cond=True)\n", - "\n", - " log_every_t = 1 if save_intermediate_vid else None\n", - "\n", - " if custom_shape is not None:\n", - " z = torch.randn(custom_shape)\n", - " # print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n", - "\n", - " z0 = None\n", - "\n", - " log[\"input\"] = x\n", - " log[\"reconstruction\"] = xrec\n", - "\n", - " if ismap(xc):\n", - " log[\"original_conditioning\"] = model.to_rgb(xc)\n", - " if hasattr(model, 'cond_stage_key'):\n", - " log[model.cond_stage_key] = model.to_rgb(xc)\n", - "\n", - " else:\n", - " log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n", - " if model.cond_stage_model:\n", - " log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n", - " if model.cond_stage_key =='class_label':\n", - " log[model.cond_stage_key] = xc[model.cond_stage_key]\n", - "\n", - " with model.ema_scope(\"Plotting\"):\n", - " t0 = time.time()\n", - " img_cb = None\n", - "\n", - " sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n", - " eta=eta,\n", - " quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,\n", - " temperature=temperature, noise_dropout=noise_dropout,\n", - " score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n", - " x_T=x_T, log_every_t=log_every_t)\n", - " t1 = time.time()\n", - "\n", - " if ddim_use_x0_pred:\n", - " sample = intermediates['pred_x0'][-1]\n", - "\n", - " x_sample = model.decode_first_stage(sample)\n", - "\n", - " try:\n", - " x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n", - " log[\"sample_noquant\"] = x_sample_noquant\n", - " log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n", - " except:\n", - " pass\n", - "\n", - " log[\"sample\"] = x_sample\n", - " log[\"time\"] = t1 - t0\n", - "\n", - " return log\n", - "\n", - "sr_diffMode = 'superresolution'\n", - "sr_model = get_model('superresolution')\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "def do_superres(img, filepath):\n", - "\n", - " if args.sharpen_preset == 'Faster':\n", - " sr_diffusion_steps = \"25\" \n", - " sr_pre_downsample = '1/2' \n", - " if args.sharpen_preset == 'Fast':\n", - " sr_diffusion_steps = \"100\" \n", - " sr_pre_downsample = '1/2' \n", - " if args.sharpen_preset == 'Slow':\n", - " sr_diffusion_steps = \"25\" \n", - " sr_pre_downsample = 'None' \n", - " if args.sharpen_preset == 'Very Slow':\n", - " sr_diffusion_steps = \"100\" \n", - " sr_pre_downsample = 'None' \n", - "\n", - "\n", - " sr_post_downsample = 'Original Size'\n", - " sr_diffusion_steps = int(sr_diffusion_steps)\n", - " sr_eta = 1.0 \n", - " sr_downsample_method = 'Lanczos' \n", - "\n", - " gc.collect()\n", - " torch.cuda.empty_cache()\n", - "\n", - " im_og = img\n", - " width_og, height_og = im_og.size\n", - "\n", - " #Downsample Pre\n", - " if sr_pre_downsample == '1/2':\n", - " downsample_rate = 2\n", - " elif sr_pre_downsample == '1/4':\n", - " downsample_rate = 4\n", - " else:\n", - " downsample_rate = 1\n", - "\n", - " width_downsampled_pre = width_og//downsample_rate\n", - " height_downsampled_pre = height_og//downsample_rate\n", - "\n", - " if downsample_rate != 1:\n", - " # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n", - " im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n", - " # im_og.save('/content/temp.png')\n", - " # filepath = '/content/temp.png'\n", - "\n", - " logs = sr_run(sr_model[\"model\"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)\n", - "\n", - " sample = logs[\"sample\"]\n", - " sample = sample.detach().cpu()\n", - " sample = torch.clamp(sample, -1., 1.)\n", - " sample = (sample + 1.) / 2. * 255\n", - " sample = sample.numpy().astype(np.uint8)\n", - " sample = np.transpose(sample, (0, 2, 3, 1))\n", - " a = Image.fromarray(sample[0])\n", - "\n", - " #Downsample Post\n", - " if sr_post_downsample == '1/2':\n", - " downsample_rate = 2\n", - " elif sr_post_downsample == '1/4':\n", - " downsample_rate = 4\n", - " else:\n", - " downsample_rate = 1\n", - "\n", - " width, height = a.size\n", - " width_downsampled_post = width//downsample_rate\n", - " height_downsampled_post = height//downsample_rate\n", - "\n", - " if sr_downsample_method == 'Lanczos':\n", - " aliasing = Image.LANCZOS\n", - " else:\n", - " aliasing = Image.NEAREST\n", - "\n", - " if downsample_rate != 1:\n", - " # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')\n", - " a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)\n", - " elif sr_post_downsample == 'Original Size':\n", - " # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')\n", - " a = a.resize((width_og, height_og), aliasing)\n", - "\n", - " display.display(a)\n", - " a.save(filepath)\n", - " return\n", - " print(f'Processing finished!')\n" - ], - "outputs": [], - "execution_count": null - }, { "cell_type": "markdown", "metadata": { @@ -2240,8 +1674,7 @@ "use_secondary_model = True #@param {type: 'boolean'}\n", "diffusion_sampling_mode = 'ddim' #@param ['plms','ddim'] \n", "\n", - "timestep_respacing = '250' #@param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] \n", - "diffusion_steps = 1000 #@param {type: 'number'}\n", + "\n", "use_checkpoint = True #@param {type: 'boolean'}\n", "ViTB32 = True #@param{type:\"boolean\"}\n", "ViTB16 = True #@param{type:\"boolean\"}\n", @@ -2333,9 +1766,9 @@ " model_config.update({\n", " 'attention_resolutions': '32, 16, 8',\n", " 'class_cond': False,\n", - " 'diffusion_steps': diffusion_steps,\n", + " 'diffusion_steps': 1000, #No need to edit this, it is taken care of later.\n", " 'rescale_timesteps': True,\n", - " 'timestep_respacing': timestep_respacing,\n", + " 'timestep_respacing': 250, #No need to edit this, it is taken care of later.\n", " 'image_size': 512,\n", " 'learn_sigma': True,\n", " 'noise_schedule': 'linear',\n", @@ -2351,9 +1784,9 @@ " model_config.update({\n", " 'attention_resolutions': '32, 16, 8',\n", " 'class_cond': False,\n", - " 'diffusion_steps': diffusion_steps,\n", + " 'diffusion_steps': 1000, #No need to edit this, it is taken care of later.\n", " 'rescale_timesteps': True,\n", - " 'timestep_respacing': timestep_respacing,\n", + " 'timestep_respacing': 250, #No need to edit this, it is taken care of later.\n", " 'image_size': 256,\n", " 'learn_sigma': True,\n", " 'noise_schedule': 'linear',\n", @@ -2398,24 +1831,10 @@ " SLIPB16model.load_state_dict(real_sd)\n", " SLIPB16model.requires_grad_(False).eval().to(device)\n", "\n", - " clip_models.append(SLIPB16model)\n", - "\n", - "if SLIPL16:\n", - " SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n", - " if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n", - " wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", model_path)\n", - " sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n", - " real_sd = {}\n", - " for k, v in sd['state_dict'].items():\n", - " real_sd['.'.join(k.split('.')[1:])] = v\n", - " del sd\n", - " SLIPL16model.load_state_dict(real_sd)\n", - " SLIPL16model.requires_grad_(False).eval().to(device)\n", - "\n", " clip_models.append(SLIPL16model)\n", "\n", "normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n", - "lpips_model = lpips.LPIPS(net='vgg').to(device)\n" + "lpips_model = lpips.LPIPS(net='vgg').to(device)" ], "outputs": [], "execution_count": null @@ -2470,7 +1889,7 @@ "\n", "#Make folder for batch\n", "batchFolder = f'{outDirPath}/{batch_name}'\n", - "createPath(batchFolder)\n" + "createPath(batchFolder)" ], "outputs": [], "execution_count": null @@ -2816,7 +2235,7 @@ " translation_z = float(translation_z)\n", " rotation_3d_x = float(rotation_3d_x)\n", " rotation_3d_y = float(rotation_3d_y)\n", - " rotation_3d_z = float(rotation_3d_z)\n" + " rotation_3d_z = float(rotation_3d_z)" ], "outputs": [], "execution_count": null @@ -2828,7 +2247,7 @@ }, "source": [ "### Extra Settings\n", - " Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling" + " Partial Saves, Advanced Settings, Cutn Scheduling" ] }, { @@ -2864,18 +2283,6 @@ "\n", " #@markdown ---\n", "\n", - "#@markdown ####**SuperRes Sharpening:**\n", - "#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*\n", - "sharpen_preset = 'Off' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']\n", - "keep_unsharp = True #@param{type: 'boolean'}\n", - "\n", - "if sharpen_preset != 'Off' and keep_unsharp is True:\n", - " unsharpenFolder = f'{batchFolder}/unsharpened'\n", - " createPath(unsharpenFolder)\n", - "\n", - "\n", - " #@markdown ---\n", - "\n", "#@markdown ####**Advanced Settings:**\n", "#@markdown *There are a few extra advanced settings available if you double click this cell.*\n", "\n", @@ -2906,7 +2313,7 @@ "cut_overview = \"[12]*400+[4]*600\" #@param {type: 'string'} \n", "cut_innercut =\"[4]*400+[12]*600\"#@param {type: 'string'} \n", "cut_ic_pow = 1#@param {type: 'number'} \n", - "cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}\n" + "cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}" ], "outputs": [], "execution_count": null @@ -2934,7 +2341,7 @@ "\n", "image_prompts = {\n", " # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n", - "}\n" + "}" ], "outputs": [], "execution_count": null @@ -3050,8 +2457,6 @@ " 'init_image': init_image,\n", " 'init_scale': init_scale,\n", " 'skip_steps': skip_steps,\n", - " 'sharpen_preset': sharpen_preset,\n", - " 'keep_unsharp': keep_unsharp,\n", " 'side_x': side_x,\n", " 'side_y': side_y,\n", " 'timestep_respacing': timestep_respacing,\n", @@ -3134,7 +2539,7 @@ "finally:\n", " print('Seed used:', seed)\n", " gc.collect()\n", - " torch.cuda.empty_cache()\n" + " torch.cuda.empty_cache()" ], "outputs": [], "execution_count": null @@ -3228,7 +2633,7 @@ " # mp4 = open(filepath,'rb').read()\n", " # data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", " # display.HTML(f'')\n", - " \n" + " " ], "outputs": [], "execution_count": null diff --git a/README.md b/README.md index a793b8c..a35472b 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,16 @@ A frankensteinian amalgamation of notebooks, models and techniques for the gener #### v5 Update: Feb 20th 2022 - gandamu / Adam Letts * Added 3D animation mode. Uses weighted combination of AdaBins and MiDaS depth estimation models. Uses pytorch3d for 3D transforms on Colab and/or Linux. +#### v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts / MSFTserver aka HostsServer + +* Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults. +* Implemented resume of turbo animations in such a way that it's now possible to resume from different batch folders and batch numbers. +* 3D rotation parameter units are now degrees (rather than radians) +* Corrected name collision in sampling_mode (now diffusion_sampling_mode for plms/ddim, and sampling_mode for 3D transform sampling) +* Added video_init_seed_continuity option to make init video animations more continuous +* Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion +* Remove Super Resolution + ## Notebook Provenance diff --git a/disco.py b/disco.py index 82c523e..c7001ab 100644 --- a/disco.py +++ b/disco.py @@ -1,9 +1,16 @@ # %% +# !! {"metadata": { +# !! "id": "view-in-github", +# !! "colab_type": "text" +# !! }} """ Open In Colab """ # %% +# !! {"metadata": { +# !! "id": "TitleTop" +# !! }} """ # Disco Diffusion v5.1 - Now with Turbo @@ -13,11 +20,17 @@ For issues, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA) or """ # %% +# !! {"metadata": { +# !! "id": "CreditsChTop" +# !! }} """ ### Credits & Changelog ⬇️ """ # %% +# !! {"metadata": { +# !! "id": "Credits" +# !! }} """ #### Credits @@ -45,11 +58,17 @@ Turbo feature by Chris Allen (https://twitter.com/zippy731) """ # %% +# !! {"metadata": { +# !! "id": "LicenseTop" +# !! }} """ #### License """ # %% +# !! {"metadata": { +# !! "id": "License" +# !! }} """ Licensed under the MIT License @@ -125,11 +144,18 @@ THE SOFTWARE. """ # %% +# !! {"metadata": { +# !! "id": "ChangelogTop" +# !! }} """ #### Changelog """ # %% +# !! {"metadata": { +# !! "cellView": "form", +# !! "id": "Changelog" +# !! }} #@title <- View Changelog skip_for_run_all = True #@param {type: 'boolean'} @@ -204,7 +230,7 @@ if skip_for_run_all == False: IPython magic commands replaced by Python code - v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts + v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts / MSFTserver aka HostsServer Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults. @@ -216,16 +242,26 @@ if skip_for_run_all == False: Added video_init_seed_continuity option to make init video animations more continuous + Removed pytorch3d from needing to be compiled with a lite version specifically made for Disco Diffusion + + Remove Super Resolution + ''' ) # %% +# !! {"metadata": { +# !! "id": "TutorialTop" +# !! }} """ # Tutorial """ # %% +# !! {"metadata": { +# !! "id": "DiffusionSet" +# !! }} """ **Diffusion settings (Defaults are heavily outdated)** --- @@ -275,11 +311,18 @@ Setting | Description | Default """ # %% +# !! {"metadata": { +# !! "id": "SetupTop" +# !! }} """ # 1. Set Up """ # %% +# !! {"metadata": { +# !! "cellView": "form", +# !! "id": "CheckGPU" +# !! }} #@title 1.1 Check GPU Status import subprocess simple_nvidia_smi_display = False#@param {type:"boolean"} @@ -295,6 +338,10 @@ else: print(nvidiasmi_ecc_note) # %% +# !! {"metadata": { +# !! "cellView": "form", +# !! "id": "PrepFolders" +# !! }} #@title 1.2 Prepare Folders import subprocess import sys @@ -363,6 +410,10 @@ else: # createPath(libraries) # %% +# !! {"metadata": { +# !! "cellView": "form", +# !! "id": "InstallDeps" +# !! }} #@title ### 1.3 Install and import dependencies import pathlib, shutil @@ -453,32 +504,7 @@ import matplotlib.pyplot as plt import random from ipywidgets import Output import hashlib - -#SuperRes -if is_colab: - gitclone("https://github.com/CompVis/latent-diffusion.git") - gitclone("https://github.com/CompVis/taming-transformers") - pipie("./taming-transformers") - pipi("ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb") - -#SuperRes -import ipywidgets as widgets -import os -sys.path.append(".") -sys.path.append('./taming-transformers') -from taming.models import vqgan # checking correct import from taming -from torchvision.datasets.utils import download_url - -if is_colab: - os.chdir('/content/latent-diffusion') -else: - #os.chdir('latent-diffusion') - sys.path.append('latent-diffusion') from functools import partial -from ldm.util import instantiate_from_config -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like -# from ldm.models.diffusion.ddim import DDIMSampler -from ldm.util import ismap if is_colab: os.chdir('/content') from google.colab import files @@ -515,6 +541,10 @@ if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad torch.backends.cudnn.enabled = False # %% +# !! {"metadata": { +# !! "cellView": "form", +# !! "id": "DefMidasFns" +# !! }} #@title ### 1.4 Define Midas functions from midas.dpt_depth import DPTDepthModel @@ -618,6 +648,10 @@ def init_midas_depth_model(midas_model_type="dpt_large", optimize=True): return midas_model, midas_transform, net_w, net_h, resize_mode, normalization # %% +# !! {"metadata": { +# !! "cellView": "form", +# !! "id": "DefFns" +# !! }} #@title 1.5 Define necessary functions # https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869 @@ -1263,7 +1297,6 @@ def do_run(): # with run_display: # display.clear_output(wait=True) - imgToSharpen = None for j, sample in enumerate(samples): cur_t -= 1 intermediateStep = False @@ -1312,31 +1345,20 @@ def do_run(): save_settings() if args.animation_mode != "None": image.save('prevFrame.png') - if args.sharpen_preset != "Off" and animation_mode == "None": - imgToSharpen = image - if args.keep_unsharp is True: - image.save(f'{unsharpenFolder}/{filename}') - else: - image.save(f'{batchFolder}/{filename}') - if args.animation_mode == "3D": - # If turbo, save a blended image - if turbo_mode: - # Mix new image with prevFrameScaled - blend_factor = (1)/int(turbo_steps) - newFrame = cv2.imread('prevFrame.png') # This is already updated.. - prev_frame_warped = cv2.imread('prevFrameScaled.png') - blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0) - cv2.imwrite(f'{batchFolder}/{filename}',blendedImage) - else: - image.save(f'{batchFolder}/{filename}') + image.save(f'{batchFolder}/{filename}') + if args.animation_mode == "3D": + # If turbo, save a blended image + if turbo_mode: + # Mix new image with prevFrameScaled + blend_factor = (1)/int(turbo_steps) + newFrame = cv2.imread('prevFrame.png') # This is already updated.. + prev_frame_warped = cv2.imread('prevFrameScaled.png') + blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0) + cv2.imwrite(f'{batchFolder}/{filename}',blendedImage) + else: + image.save(f'{batchFolder}/{filename}') # if frame_num != args.max_frames-1: # display.clear_output() - - with image_display: - if args.sharpen_preset != "Off" and animation_mode == "None": - print('Starting Diffusion Sharpening...') - do_superres(imgToSharpen, f'{batchFolder}/{filename}') - display.clear_output() plt.plot(np.array(loss_values), 'r') @@ -1418,6 +1440,10 @@ def save_settings(): json.dump(setting_list, f, ensure_ascii=False, indent=4) # %% +# !! {"metadata": { +# !! "cellView": "form", +# !! "id": "DefSecModel" +# !! }} #@title 1.6 Define the secondary diffusion model def append_dims(x, n): @@ -1582,537 +1608,19 @@ class SecondaryDiffusionImageNet2(nn.Module): eps = input * sigmas + v * alphas return DiffusionOutput(v, pred, eps) -# %% -#@title 1.7 SuperRes Define -class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - # print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t - ) - return samples, intermediates - - @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sharpening with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img - - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs) - img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): - b, *_, device = *x.shape, x.device - e_t = self.model.apply_model(x, t, c) - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - -def download_models(mode): - - if mode == "superresolution": - # this is the small bsr light model - url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' - url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' - - path_conf = f'{model_path}/superres/project.yaml' - path_ckpt = f'{model_path}/superres/last.ckpt' - - download_url(url_conf, path_conf) - download_url(url_ckpt, path_ckpt) - - path_conf = path_conf + '/?dl=1' # fix it - path_ckpt = path_ckpt + '/?dl=1' # fix it - return path_conf, path_ckpt - - else: - raise NotImplementedError - - -def load_model_from_config(config, ckpt): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - global_step = pl_sd["global_step"] - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - model.cuda() - model.eval() - return {"model": model}, global_step - - -def get_model(mode): - path_conf, path_ckpt = download_models(mode) - config = OmegaConf.load(path_conf) - model, step = load_model_from_config(config, path_ckpt) - return model - - -def get_custom_cond(mode): - dest = "data/example_conditioning" - - if mode == "superresolution": - uploaded_img = files.upload() - filename = next(iter(uploaded_img)) - name, filetype = filename.split(".") # todo assumes just one dot in name ! - os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}") - - elif mode == "text_conditional": - w = widgets.Text(value='A cake with cream!', disabled=True) - display.display(w) - - with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f: - f.write(w.value) - - elif mode == "class_conditional": - w = widgets.IntSlider(min=0, max=1000) - display.display(w) - with open(f"{dest}/{mode}/custom.txt", 'w') as f: - f.write(w.value) - - else: - raise NotImplementedError(f"cond not implemented for mode{mode}") - - -def get_cond_options(mode): - path = "data/example_conditioning" - path = os.path.join(path, mode) - onlyfiles = [f for f in sorted(os.listdir(path))] - return path, onlyfiles - - -def select_cond_path(mode): - path = "data/example_conditioning" # todo - path = os.path.join(path, mode) - onlyfiles = [f for f in sorted(os.listdir(path))] - - selected = widgets.RadioButtons( - options=onlyfiles, - description='Select conditioning:', - disabled=False - ) - display.display(selected) - selected_path = os.path.join(path, selected.value) - return selected_path - - -def get_cond(mode, img): - example = dict() - if mode == "superresolution": - up_f = 4 - # visualize_cond_img(selected_path) - - c = img - c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) - c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True) - c_up = rearrange(c_up, '1 c h w -> 1 h w c') - c = rearrange(c, '1 c h w -> 1 h w c') - c = 2. * c - 1. - - c = c.to(torch.device("cuda")) - example["LR_image"] = c - example["image"] = c_up - - return example - - -def visualize_cond_img(path): - display.display(ipyimg(filename=path)) - - -def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None): - # global stride - - example = get_cond(task, img) - - save_intermediate_vid = False - n_runs = 1 - masked = False - guider = None - ckwargs = None - mode = 'ddim' - ddim_use_x0_pred = False - temperature = 1. - eta = eta - make_progrow = True - custom_shape = None - - height, width = example["image"].shape[1:3] - split_input = height >= 128 and width >= 128 - - if split_input: - ks = 128 - stride = 64 - vqf = 4 # - model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), - "vqf": vqf, - "patch_distributed_vq": True, - "tie_braker": False, - "clip_max_weight": 0.5, - "clip_min_weight": 0.01, - "clip_max_tie_weight": 0.5, - "clip_min_tie_weight": 0.01} - else: - if hasattr(model, "split_input_params"): - delattr(model, "split_input_params") - - invert_mask = False - - x_T = None - for n in range(n_runs): - if custom_shape is not None: - x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) - x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0]) - - logs = make_convolutional_sample(example, model, - mode=mode, custom_steps=custom_steps, - eta=eta, swap_mode=False , masked=masked, - invert_mask=invert_mask, quantize_x0=False, - custom_schedule=None, decode_interval=10, - resize_enabled=resize_enabled, custom_shape=custom_shape, - temperature=temperature, noise_dropout=0., - corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid, - make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred - ) - return logs - - -@torch.no_grad() -def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, - mask=None, x0=None, quantize_x0=False, img_callback=None, - temperature=1., noise_dropout=0., score_corrector=None, - corrector_kwargs=None, x_T=None, log_every_t=None - ): - - ddim = DDIMSampler(model) - bs = shape[0] # dont know where this comes from but wayne - shape = shape[1:] # cut batch dim - # print(f"Sampling with eta = {eta}; steps: {steps}") - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, - normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, - mask=mask, x0=x0, temperature=temperature, verbose=False, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, x_T=x_T) - - return samples, intermediates - - -@torch.no_grad() -def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False, - invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000, - resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, - corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False): - log = dict() - - z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=not (hasattr(model, 'split_input_params') - and model.cond_stage_key == 'coordinates_bbox'), - return_original_cond=True) - - log_every_t = 1 if save_intermediate_vid else None - - if custom_shape is not None: - z = torch.randn(custom_shape) - # print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") - - z0 = None - - log["input"] = x - log["reconstruction"] = xrec - - if ismap(xc): - log["original_conditioning"] = model.to_rgb(xc) - if hasattr(model, 'cond_stage_key'): - log[model.cond_stage_key] = model.to_rgb(xc) - - else: - log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) - if model.cond_stage_model: - log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) - if model.cond_stage_key =='class_label': - log[model.cond_stage_key] = xc[model.cond_stage_key] - - with model.ema_scope("Plotting"): - t0 = time.time() - img_cb = None - - sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, - eta=eta, - quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0, - temperature=temperature, noise_dropout=noise_dropout, - score_corrector=corrector, corrector_kwargs=corrector_kwargs, - x_T=x_T, log_every_t=log_every_t) - t1 = time.time() - - if ddim_use_x0_pred: - sample = intermediates['pred_x0'][-1] - - x_sample = model.decode_first_stage(sample) - - try: - x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) - log["sample_noquant"] = x_sample_noquant - log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except: - pass - - log["sample"] = x_sample - log["time"] = t1 - t0 - - return log - -sr_diffMode = 'superresolution' -sr_model = get_model('superresolution') - - - - - - -def do_superres(img, filepath): - - if args.sharpen_preset == 'Faster': - sr_diffusion_steps = "25" - sr_pre_downsample = '1/2' - if args.sharpen_preset == 'Fast': - sr_diffusion_steps = "100" - sr_pre_downsample = '1/2' - if args.sharpen_preset == 'Slow': - sr_diffusion_steps = "25" - sr_pre_downsample = 'None' - if args.sharpen_preset == 'Very Slow': - sr_diffusion_steps = "100" - sr_pre_downsample = 'None' - - - sr_post_downsample = 'Original Size' - sr_diffusion_steps = int(sr_diffusion_steps) - sr_eta = 1.0 - sr_downsample_method = 'Lanczos' - - gc.collect() - torch.cuda.empty_cache() - - im_og = img - width_og, height_og = im_og.size - - #Downsample Pre - if sr_pre_downsample == '1/2': - downsample_rate = 2 - elif sr_pre_downsample == '1/4': - downsample_rate = 4 - else: - downsample_rate = 1 - - width_downsampled_pre = width_og//downsample_rate - height_downsampled_pre = height_og//downsample_rate - - if downsample_rate != 1: - # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') - im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) - # im_og.save('/content/temp.png') - # filepath = '/content/temp.png' - - logs = sr_run(sr_model["model"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta) - - sample = logs["sample"] - sample = sample.detach().cpu() - sample = torch.clamp(sample, -1., 1.) - sample = (sample + 1.) / 2. * 255 - sample = sample.numpy().astype(np.uint8) - sample = np.transpose(sample, (0, 2, 3, 1)) - a = Image.fromarray(sample[0]) - - #Downsample Post - if sr_post_downsample == '1/2': - downsample_rate = 2 - elif sr_post_downsample == '1/4': - downsample_rate = 4 - else: - downsample_rate = 1 - - width, height = a.size - width_downsampled_post = width//downsample_rate - height_downsampled_post = height//downsample_rate - - if sr_downsample_method == 'Lanczos': - aliasing = Image.LANCZOS - else: - aliasing = Image.NEAREST - - if downsample_rate != 1: - # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]') - a = a.resize((width_downsampled_post, height_downsampled_post), aliasing) - elif sr_post_downsample == 'Original Size': - # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]') - a = a.resize((width_og, height_og), aliasing) - - display.display(a) - a.save(filepath) - return - print(f'Processing finished!') - # %% +# !! {"metadata": { +# !! "id": "DiffClipSetTop" +# !! }} """ # 2. Diffusion and CLIP model settings """ # %% +# !! {"metadata": { +# !! "id": "ModelSettings" +# !! }} #@markdown ####**Models Settings:** diffusion_model = "512x512_diffusion_uncond_finetune_008100" #@param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"] use_secondary_model = True #@param {type: 'boolean'} @@ -2275,20 +1783,6 @@ if SLIPB16: SLIPB16model.load_state_dict(real_sd) SLIPB16model.requires_grad_(False).eval().to(device) - clip_models.append(SLIPB16model) - -if SLIPL16: - SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256) - if not os.path.exists(f'{model_path}/slip_large_100ep.pt'): - wget("https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt", model_path) - sd = torch.load(f'{model_path}/slip_large_100ep.pt') - real_sd = {} - for k, v in sd['state_dict'].items(): - real_sd['.'.join(k.split('.')[1:])] = v - del sd - SLIPL16model.load_state_dict(real_sd) - SLIPL16model.requires_grad_(False).eval().to(device) - clip_models.append(SLIPL16model) normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) @@ -2296,11 +1790,17 @@ lpips_model = lpips.LPIPS(net='vgg').to(device) # %% +# !! {"metadata": { +# !! "id": "SettingsTop" +# !! }} """ # 3. Settings """ # %% +# !! {"metadata": { +# !! "id": "BasicSettings" +# !! }} #@markdown ####**Basic Settings:** batch_name = 'TimeToDisco' #@param{type: 'string'} steps = 250 #@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true} @@ -2340,11 +1840,17 @@ createPath(batchFolder) # %% +# !! {"metadata": { +# !! "id": "AnimSetTop" +# !! }} """ ### Animation Settings """ # %% +# !! {"metadata": { +# !! "id": "AnimSettings" +# !! }} #@markdown ####**Animation Mode:** animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'} #@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.* @@ -2675,12 +2181,18 @@ else: # %% +# !! {"metadata": { +# !! "id": "ExtraSetTop" +# !! }} """ ### Extra Settings - Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling + Partial Saves, Advanced Settings, Cutn Scheduling """ # %% +# !! {"metadata": { +# !! "id": "ExtraSettings" +# !! }} #@markdown ####**Saving:** intermediate_saves = 0#@param{type: 'raw'} @@ -2706,18 +2218,6 @@ if intermediate_saves and intermediates_in_subfolder is True: partialFolder = f'{batchFolder}/partials' createPath(partialFolder) - #@markdown --- - -#@markdown ####**SuperRes Sharpening:** -#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.* -sharpen_preset = 'Off' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow'] -keep_unsharp = True #@param{type: 'boolean'} - -if sharpen_preset != 'Off' and keep_unsharp is True: - unsharpenFolder = f'{batchFolder}/unsharpened' - createPath(unsharpenFolder) - - #@markdown --- #@markdown ####**Advanced Settings:** @@ -2754,12 +2254,18 @@ cut_icgray_p = "[0.2]*400+[0]*600"#@param {type: 'string'} # %% +# !! {"metadata": { +# !! "id": "PromptsTop" +# !! }} """ ### Prompts `animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one. """ # %% +# !! {"metadata": { +# !! "id": "Prompts" +# !! }} text_prompts = { 0: ["A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.", "yellow color scheme"], 100: ["This set of prompts start at frame 100","This prompt has weight five:5"], @@ -2771,11 +2277,17 @@ image_prompts = { # %% +# !! {"metadata": { +# !! "id": "DiffuseTop" +# !! }} """ # 4. Diffuse! """ # %% +# !! {"metadata": { +# !! "id": "DoTheRun" +# !! }} #@title Do the Run! #@markdown `n_batches` ignored with animation modes. display_rate = 50 #@param{type: 'number'} @@ -2872,8 +2384,6 @@ args = { 'init_image': init_image, 'init_scale': init_scale, 'skip_steps': skip_steps, - 'sharpen_preset': sharpen_preset, - 'keep_unsharp': keep_unsharp, 'side_x': side_x, 'side_y': side_y, 'timestep_respacing': timestep_respacing, @@ -2960,11 +2470,17 @@ finally: # %% +# !! {"metadata": { +# !! "id": "CreateVidTop" +# !! }} """ # 5. Create the video """ # %% +# !! {"metadata": { +# !! "id": "CreateVid" +# !! }} # @title ### **Create video** #@markdown Video file will save in the same folder as your images.