{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "QoL MP Diffusion v2 [w/ Secondary Model v2].ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [
        "XTu6AjLyFQUq"
      ],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1YwMUyt9LHG1"
      },
      "source": [
        "# Generates images from text prompts with CLIP guided diffusion.\n",
        "\n",
        "By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet diffusion model or Katherine Crowson's fine-tuned 512x512 one (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP), to connect text prompts with images.\n",
        "\n",
        "Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.\n",
        "\n",
        "**Update**: Sep 19th 2021\n",
        "\n",
        "Further improvements from Dango233 and nshepperd helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.\n",
        "\n",
        "Katherine's original notebook can be found here:\n",
        "https://colab.research.google.com/drive/1QBsaDAZv8np29FPbvjffbE1eytoJcsgA\n",
        "\n",
        "Vark added code to load multiple CLIP models at once; all prompts are evaluated against every loaded model, which may greatly improve accuracy.\n",
        "\n",
        "--\n",
        "\n",
        "I, Somnai (https://twitter.com/Somnai_dreams), have made the following QoL improvements and assorted implementations:\n",
        "\n",
        "**Update**: Oct 29th 2021\n",
        "\n",
        "QoL improvements added by Somnai (@somnai_dreams), including a user-friendly UI, settings and prompt saving, and improved Google Drive folder organization.\n",
        "\n",
        "**Update**: Nov 13th 2021\n",
        "\n",
        "Now includes sizing options, intermediate saves, and fixed image prompts and perlin inits. The batch option is left unexposed since it doesn't work.\n",
        "\n",
        "**Update**: Nov 22nd 2021\n",
        "\n",
        "Initial addition of Katherine Crowson's Secondary Model Method (https://colab.research.google.com/drive/1mpkrhOjoyzPeSWy2r7T8EYRaU7amYOOi#scrollTo=X5gODNAMEUCR).\n",
        "\n",
        "Noticed settings were saving with the wrong name, so I corrected it. Let me know if you preferred the old scheme."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XTu6AjLyFQUq"
      },
      "source": [
        "#Tutorial"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YR806W0wi3He"
      },
      "source": [
        "**Diffusion settings**\n",
        "---\n",
        "\n",
        "Setting | Description | Default\n",
        "--- | --- | ---\n",
        "**Your vision:**\n",
        "`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A\n",
        "`image_prompts` | Think of these images more as a description of their contents. | N/A\n",
        "**Image quality:**\n",
        "`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000\n",
        "`tv_scale` | Controls the smoothness of the final output. | 150\n",
        "`range_scale` | Controls how far out of range RGB values are allowed to be. | 150\n",
        "`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0\n",
        "`cutn` | Controls how many crops to take from the image. | 16\n",
        "`cutn_batches` | Accumulates CLIP gradients from multiple batches of cuts. | 2\n",
        "**Init settings:**\n",
        "`init_image` | URL or local path. | None\n",
        "`init_scale` | Enhances the effect of the init image; a good value is 1000. | 0\n",
        "`skip_timesteps` | Controls the starting point along the diffusion timesteps. | 0\n",
        "`perlin_init` | Option to start with random perlin noise. | False\n",
        "`perlin_mode` | ('gray', 'color') | 'mixed'\n",
        "**Advanced:**\n",
        "`skip_augs` | Controls whether to skip torchvision augmentations. | False\n",
        "`randomize_class` | Controls whether the ImageNet class is randomly changed each iteration. | True\n",
        "`clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image. | False\n",
        "`clamp_grad` | Experimental: use adaptive gradient clamping in the cond_fn. | True\n",
        "`seed` | Chooses a random seed and prints it at the end of the run for reproduction. | random_seed\n",
        "`fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses. | False\n",
        "`rand_mag` | Controls the magnitude of the random noise. | 0.1\n",
        "`eta` | DDIM hyperparameter. | 0.5\n",
        "\n",
        "..\n",
        "\n",
        "**Model settings**\n",
        "---\n",
        "\n",
        "Setting | Description | Default\n",
        "--- | --- | ---\n",
        "**Diffusion:**\n",
        "`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100\n",
        "`diffusion_steps` | Base number of diffusion timesteps (before respacing). | 1000\n",
        "**CLIP:**\n",
        "`clip_models` | Models of CLIP to load. Typically the more, the better, but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_9Eg9Kf5FlfK"
      },
      "source": [
        "# 1. Pre-Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qZ3rNuAWAewx",
        "cellView": "form"
      },
      "source": [
        "#@title 1.1 Check GPU Status\n",
        "!nvidia-smi"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yZsjzwS0YGo6",
        "cellView": "form"
      },
      "source": [
        "#@title 1.2 Prepare Folders\n",
        "#@markdown If you connect your Google Drive, you can save the final image of each run on your drive.\n",
        "\n",
        "google_drive = True #@param {type:\"boolean\"}\n",
        "\n",
        "#@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
        "yes_please = True #@param {type:\"boolean\"}\n",
        "\n",
        "from google.colab import drive\n",
        "\n",
        "if google_drive:\n",
        "    drive.mount('/content/drive')\n",
        "    root_path = '/content/drive/MyDrive/AI/MP_Diffusion'\n",
        "else:\n",
        "    root_path = '/content'\n",
        "\n",
        "import os\n",
        "from os import path\n",
        "\n",
        "# Simple create-path helper, adapted from Datamosh's Batch VQGAN+CLIP notebook\n",
        "def createPath(filepath):\n",
        "    if not path.exists(filepath):\n",
        "        os.makedirs(filepath)\n",
        "        print(f'Made {filepath}')\n",
        "    else:\n",
        "        print(f'filepath {filepath} exists.')\n",
        "\n",
        "initDirPath = f'{root_path}/init_images'\n",
        "createPath(initDirPath)\n",
        "outDirPath = f'{root_path}/images_out'\n",
        "createPath(outDirPath)\n",
        "\n",
        "if google_drive and yes_please:\n",
        "    model_path = f'{root_path}/models'\n",
        "else:\n",
        "    model_path = '/content/models'\n",
        "createPath(model_path)\n",
        "# libraries = f'{root_path}/libraries'\n",
        "# createPath(libraries)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "otQKpqkGrF2r"
      },
      "source": [
        "#2. Install\n",
        "\n",
        "Run this once at the start of your session and after a restart."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JmbrcrhpBPC6",
        "cellView": "form"
      },
      "source": [
        "#@title ### 2.1 Install and import dependencies\n",
        "\n",
        "if not google_drive:\n",
        "    root_path = '/content'\n",
        "    model_path = '/content/'\n",
        "\n",
        "!git clone https://github.com/openai/CLIP\n",
        "!git clone https://github.com/crowsonkb/guided-diffusion\n",
        "!pip install -e ./CLIP\n",
        "!pip install -e ./guided-diffusion\n",
        "!pip install lpips  # datetime comes with the standard library, no install needed\n",
        "\n",
        "from dataclasses import dataclass\n",
        "from functools import partial\n",
        "import gc\n",
        "import io\n",
        "import math\n",
        "import sys\n",
        "from IPython import display\n",
        "import lpips\n",
        "from PIL import Image, ImageOps\n",
        "import requests\n",
        "from glob import glob\n",
        "import json\n",
        "import torch\n",
        "from torch import nn\n",
        "from torch.nn import functional as F\n",
        "import torchvision.transforms as T\n",
        "import torchvision.transforms.functional as TF\n",
        "from tqdm.notebook import tqdm\n",
        "sys.path.append('./CLIP')\n",
        "sys.path.append('./guided-diffusion')\n",
        "import clip\n",
        "from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n",
        "from datetime import datetime\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import random\n",
        "\n",
        "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
        "print('Using device:', device)\n",
        "\n",
        "if torch.cuda.get_device_capability(device) == (8, 0):  # A100 fix thanks to Emad\n",
        "    print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n",
        "    torch.backends.cudnn.enabled = False"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FpZczxnOnPIU",
        "cellView": "form"
      },
      "source": [
        "#@title 2.2 Define necessary functions\n",
        "\n",
        "# Perlin noise helpers from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n",
        "\n",
        "def interp(t):\n",
        "    return 3 * t**2 - 2 * t**3\n",
        "\n",
        "def perlin(width, height, scale=10, device=None):\n",
        "    gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n",
        "    xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n",
        "    ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n",
        "    wx = 1 - interp(xs)\n",
        "    wy = 1 - interp(ys)\n",
        "    dots = 0\n",
        "    dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n",
        "    dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n",
        "    dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n",
        "    dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n",
        "    return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n",
        "\n",
        "def perlin_ms(octaves, width, height, grayscale, device=device):\n",
        "    out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n",
        "    # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n",
        "    for i in range(1 if grayscale else 3):\n",
        "        scale = 2 ** len(octaves)\n",
        "        oct_width = width\n",
        "        oct_height = height\n",
        "        for oct in octaves:\n",
        "            p = perlin(oct_width, oct_height, scale, device)\n",
        "            out_array[i] += p * oct\n",
        "            scale //= 2\n",
        "            oct_width *= 2\n",
        "            oct_height *= 2\n",
        "    return torch.cat(out_array)\n",
        "\n",
        "def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n",
        "    out = perlin_ms(octaves, width, height, grayscale)\n",
        "    if grayscale:\n",
        "        out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))\n",
        "        out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n",
        "    else:\n",
        "        out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n",
        "        out = TF.resize(size=(side_y, side_x), img=out)\n",
        "        out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n",
        "\n",
        "    out = ImageOps.autocontrast(out)\n",
        "    return out\n",
        "\n",
        "def fetch(url_or_path):\n",
        "    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n",
        "        r = requests.get(url_or_path)\n",
        "        r.raise_for_status()\n",
        "        fd = io.BytesIO()\n",
        "        fd.write(r.content)\n",
        "        fd.seek(0)\n",
        "        return fd\n",
        "    return open(url_or_path, 'rb')\n",
        "\n",
        "\n",
        "def parse_prompt(prompt):\n",
        "    if prompt.startswith('http://') or prompt.startswith('https://'):\n",
        "        vals = prompt.rsplit(':', 2)\n",
        "        vals = [vals[0] + ':' + vals[1], *vals[2:]]\n",
        "    else:\n",
        "        vals = prompt.rsplit(':', 1)\n",
        "    vals = vals + ['', '1'][len(vals):]\n",
        "    return vals[0], float(vals[1])\n",
        "\n",
        "def sinc(x):\n",
        "    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n",
        "\n",
        "def lanczos(x, a):\n",
        "    cond = torch.logical_and(-a < x, x < a)\n",
        "    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n",
        "    return out / out.sum()\n",
        "\n",
        "def ramp(ratio, width):\n",
        "    n = math.ceil(width / ratio + 1)\n",
        "    out = torch.empty([n])\n",
        "    cur = 0\n",
        "    for i in range(out.shape[0]):\n",
        "        out[i] = cur\n",
        "        cur += ratio\n",
        "    return torch.cat([-out[1:].flip([0]), out])[1:-1]\n",
        "\n",
        "def resample(input, size, align_corners=True):\n",
        "    n, c, h, w = input.shape\n",
        "    dh, dw = size\n",
        "\n",
        "    input = input.reshape([n * c, 1, h, w])\n",
        "\n",
        "    if dh < h:\n",
        "        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n",
        "        pad_h = (kernel_h.shape[0] - 1) // 2\n",
        "        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n",
        "        input = F.conv2d(input, kernel_h[None, None, :, None])\n",
        "\n",
        "    if dw < w:\n",
        "        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n",
        "        pad_w = (kernel_w.shape[0] - 1) // 2\n",
        "        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n",
        "        input = F.conv2d(input, kernel_w[None, None, None, :])\n",
        "\n",
        "    input = input.reshape([n, c, h, w])\n",
        "    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n",
        "\n",
        "class MakeCutouts(nn.Module):\n",
        "    def __init__(self, cut_size, cutn, skip_augs=False):\n",
        "        super().__init__()\n",
        "        self.cut_size = cut_size\n",
        "        self.cutn = cutn\n",
        "        self.skip_augs = skip_augs\n",
        "        self.augs = T.Compose([\n",
        "            T.RandomHorizontalFlip(p=0.5),\n",
        "            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
        "            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
        "            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
        "            T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
        "            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
        "            T.RandomGrayscale(p=0.15),\n",
        "            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
        "            # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
        "        ])\n",
        "\n",
        "    def forward(self, input):\n",
        "        input = T.Pad(input.shape[2]//4, fill=0)(input)\n",
        "        sideY, sideX = input.shape[2:4]\n",
        "        max_size = min(sideX, sideY)\n",
        "\n",
        "        cutouts = []\n",
        "        for ch in range(self.cutn):\n",
        "            if ch > self.cutn - self.cutn//4:\n",
        "                cutout = input.clone()\n",
        "            else:\n",
        "                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n",
        "                offsetx = torch.randint(0, abs(sideX - size + 1), ())\n",
        "                offsety = torch.randint(0, abs(sideY - size + 1), ())\n",
        "                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
        "\n",
        "            if not self.skip_augs:\n",
        "                cutout = self.augs(cutout)\n",
        "            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n",
        "            del cutout\n",
        "\n",
        "        cutouts = torch.cat(cutouts, dim=0)\n",
        "        return cutouts\n",
        "\n",
        "\n",
        "def spherical_dist_loss(x, y):\n",
        "    x = F.normalize(x, dim=-1)\n",
        "    y = F.normalize(y, dim=-1)\n",
        "    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n",
        "\n",
        "\n",
        "def tv_loss(input):\n",
        "    \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n",
        "    input = F.pad(input, (0, 1, 0, 1), 'replicate')\n",
        "    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n",
        "    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n",
        "    return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n",
        "\n",
        "\n",
        "def range_loss(input):\n",
        "    return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n",
        "\n",
        "\n",
        "def do_run():\n",
        "    loss_values = []\n",
        "\n",
        "    if seed is not None:\n",
        "        np.random.seed(seed)\n",
        "        random.seed(seed)\n",
        "        torch.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "\n",
        "    model_stats = []\n",
        "    for clip_model in clip_models:\n",
        "        model_stat = {\"clip_model\": None, \"target_embeds\": [], \"make_cutouts\": None, \"weights\": []}\n",
        "        model_stat[\"clip_model\"] = clip_model\n",
        "        model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs)\n",
        "\n",
        "        for prompt in text_prompts:\n",
        "            txt, weight = parse_prompt(prompt)\n",
        "            txt = clip_model.encode_text(clip.tokenize(txt).to(device)).float()\n",
        "\n",
        "            if fuzzy_prompt:\n",
        "                for i in range(25):\n",
        "                    model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * rand_mag).clamp(0,1))\n",
        "                    model_stat[\"weights\"].append(weight)\n",
        "            else:\n",
        "                model_stat[\"target_embeds\"].append(txt)\n",
        "                model_stat[\"weights\"].append(weight)\n",
        "\n",
        "        for prompt in image_prompts:\n",
        "            path, weight = parse_prompt(prompt)\n",
        "            img = Image.open(fetch(path)).convert('RGB')\n",
        "            img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)\n",
        "            batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n",
        "            embed = clip_model.encode_image(normalize(batch)).float()\n",
        "            if fuzzy_prompt:\n",
        "                for i in range(25):\n",
        "                    model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n",
        "                    model_stat[\"weights\"].extend([weight / cutn] * cutn)\n",
        "            else:\n",
        "                model_stat[\"target_embeds\"].append(embed)\n",
        "                model_stat[\"weights\"].extend([weight / cutn] * cutn)\n",
        "\n",
        "        model_stat[\"target_embeds\"] = torch.cat(model_stat[\"target_embeds\"])\n",
        "        model_stat[\"weights\"] = torch.tensor(model_stat[\"weights\"], device=device)\n",
        "        if model_stat[\"weights\"].sum().abs() < 1e-3:\n",
        "            raise RuntimeError('The weights must not sum to 0.')\n",
        "        model_stat[\"weights\"] /= model_stat[\"weights\"].sum().abs()\n",
        "        model_stats.append(model_stat)\n",
        "\n",
        "    init = None\n",
        "    if init_image is not None:\n",
        "        init = Image.open(fetch(init_image)).convert('RGB')\n",
        "        init = init.resize((side_x, side_y), Image.LANCZOS)\n",
        "        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)\n",
        "\n",
        "    if perlin_init:\n",
        "        if perlin_mode == 'color':\n",
        "            init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
        "            init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
        "        elif perlin_mode == 'gray':\n",
        "            init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
        "            init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
        "        else:\n",
        "            init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
        "            init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
        "\n",
        "        # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)\n",
        "        init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
        "        del init2\n",
        "\n",
        "    cur_t = None\n",
        "\n",
        "    def cond_fn(x, t, y=None):\n",
        "        with torch.enable_grad():\n",
        "            x = x.detach().requires_grad_()\n",
        "            n = x.shape[0]\n",
        "            if use_secondary_model:\n",
        "                alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
        "                sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
        "                cosine_t = alpha_sigma_to_t(alpha, sigma)\n",
        "                out = secondary_model(x, cosine_t[None].repeat([n])).pred\n",
        "                fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
        "                x_in = out * fac + x * (1 - fac)\n",
        "            else:\n",
        "                my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t\n",
        "                out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})\n",
        "                fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
        "                x_in = out['pred_xstart'] * fac + x * (1 - fac)\n",
        "            x_in_grad = torch.zeros_like(x_in)\n",
        "            for model_stat in model_stats:\n",
        "                for i in range(cutn_batches):\n",
        "                    clip_in = normalize(model_stat[\"make_cutouts\"](x_in.add(1).div(2)))\n",
        "                    image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n",
        "                    dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat[\"target_embeds\"].unsqueeze(0))\n",
        "                    dists = dists.view([cutn, n, -1])\n",
        "                    losses = dists.mul(model_stat[\"weights\"]).sum(2).mean(0)\n",
        "                    loss_values.append(losses.sum().item())  # log loss, probably shouldn't do per cutn_batch\n",
        "                    x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches\n",
        "            tv_losses = tv_loss(x_in)\n",
        "            if use_secondary_model:\n",
        "                range_losses = range_loss(out)\n",
        "            else:\n",
        "                range_losses = range_loss(out['pred_xstart'])\n",
        "            sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()\n",
        "            loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale\n",
        "            if init is not None and init_scale:\n",
        "                init_losses = lpips_model(x_in, init)\n",
        "                loss = loss + init_losses.sum() * init_scale\n",
        "            x_in_grad += torch.autograd.grad(loss, x_in)[0]\n",
        "            grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]\n",
        "        if clamp_grad:\n",
        "            magnitude = grad.square().mean().sqrt()\n",
        "            return grad * magnitude.clamp(max=0.05) / magnitude\n",
        "        return grad\n",
        "\n",
        "    if model_config['timestep_respacing'].startswith('ddim'):\n",
        "        sample_fn = diffusion.ddim_sample_loop_progressive\n",
        "    else:\n",
        "        sample_fn = diffusion.p_sample_loop_progressive\n",
        "\n",
        "    for i in range(n_batches):\n",
        "        cur_t = diffusion.num_timesteps - skip_timesteps - 1\n",
        "        total_steps = cur_t\n",
        "\n",
        "        if model_config['timestep_respacing'].startswith('ddim'):\n",
        "            samples = sample_fn(\n",
        "                model,\n",
        "                (batch_size, 3, side_y, side_x),\n",
        "                clip_denoised=clip_denoised,\n",
        "                model_kwargs={},\n",
        "                cond_fn=cond_fn,\n",
        "                progress=True,\n",
        "                skip_timesteps=skip_timesteps,\n",
        "                init_image=init,\n",
        "                randomize_class=randomize_class,\n",
        "                eta=eta,\n",
        "            )\n",
        "        else:\n",
        "            samples = sample_fn(\n",
        "                model,\n",
        "                (batch_size, 3, side_y, side_x),\n",
        "                clip_denoised=clip_denoised,\n",
        "                model_kwargs={},\n",
        "                cond_fn=cond_fn,\n",
        "                progress=True,\n",
        "                skip_timesteps=skip_timesteps,\n",
        "                init_image=init,\n",
        "                randomize_class=randomize_class,\n",
        "            )\n",
        "\n",
        "        for j, sample in enumerate(samples):\n",
        "            display.clear_output(wait=True)\n",
        "            cur_t -= 1\n",
        "            intermediateStep = False\n",
        "            if steps_per_checkpoint is not None:\n",
        "                if j % steps_per_checkpoint == 0 and j > 0:\n",
        "                    intermediateStep = True\n",
        "            elif j in intermediate_saves:\n",
        "                intermediateStep = True\n",
        "            if j % display_rate == 0 or cur_t == -1 or intermediateStep:\n",
        "                for k, image in enumerate(sample['pred_xstart']):\n",
        "                    tqdm.write(f'Batch {i}, step {j}, output {k}:')\n",
        "                    current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')\n",
        "                    percent = math.ceil(j/total_steps*100)\n",
        "                    if n_batches > 0:\n",
        "                        # If intermediates are saved to the subfolder, don't append a step or percentage to the name\n",
        "                        if cur_t == -1 and intermediates_in_subfolder:\n",
        "                            filename = f'{batch_name}({batchNum})_{i:04}.png'\n",
        "                        else:\n",
        "                            # If we're working with percentages, append the percentage\n",
        "                            if steps_per_checkpoint is not None:\n",
        "                                filename = f'{batch_name}({batchNum})_{i:04}-{percent:02}%.png'\n",
        "                            # Otherwise, if we're working with specific steps, append the step\n",
        "                            else:\n",
        "                                filename = f'{batch_name}({batchNum})_{i:04}-{j:03}.png'\n",
        "                    image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))\n",
        "                    image.save('progress.png')\n",
        "                    display.display(display.Image('progress.png'))\n",
        "                    if steps_per_checkpoint is not None:\n",
        "                        if j % steps_per_checkpoint == 0 and j > 0:\n",
        "                            if intermediates_in_subfolder:\n",
        "                                image.save(f'{partialFolder}/{filename}')\n",
        "                            else:\n",
        "                                image.save(f'{batchFolder}/{filename}')\n",
        "                    else:\n",
        "                        if j in intermediate_saves:\n",
        "                            if intermediates_in_subfolder:\n",
        "                                image.save(f'{partialFolder}/{filename}')\n",
        "                            else:\n",
        "                                image.save(f'{batchFolder}/{filename}')\n",
        "                    if cur_t == -1:\n",
        "                        if i == 0:\n",
        "                            save_settings()\n",
        "                        image.save(f'{batchFolder}/{filename}')\n",
        "\n",
        "    plt.plot(np.array(loss_values), 'r')\n",
        "\n",
        "def save_settings():\n",
        "    setting_list = {\n",
        "        'text_prompts': text_prompts,\n",
        "        'image_prompts': image_prompts,\n",
        "        'clip_guidance_scale': clip_guidance_scale,\n",
        "        'tv_scale': tv_scale,\n",
        "        'range_scale': range_scale,\n",
        "        'sat_scale': sat_scale,\n",
        "        'cutn': cutn,\n",
        "        'cutn_batches': cutn_batches,\n",
        "        'init_image': init_image,\n",
        "        'init_scale': init_scale,\n",
        "        'skip_timesteps': skip_timesteps,\n",
        "        'perlin_init': perlin_init,\n",
        "        'perlin_mode': perlin_mode,\n",
        "        'skip_augs': skip_augs,\n",
        "        'randomize_class': randomize_class,\n",
        "        'clip_denoised': clip_denoised,\n",
        "        'clamp_grad': clamp_grad,\n",
        "        'seed': seed,\n",
        "        'fuzzy_prompt': fuzzy_prompt,\n",
        "        'rand_mag': rand_mag,\n",
        "        'eta': eta,\n",
        "        'width': width,\n",
        "        'height': height,\n",
        "        'diffusion_model': diffusion_model,\n",
        "        'use_secondary_model': use_secondary_model,\n",
        "        'timestep_respacing': timestep_respacing,\n",
        "        'diffusion_steps': diffusion_steps,\n",
        "        'ViTB32': ViTB32,\n",
        "        'ViTB16': ViTB16,\n",
        "        'RN101': RN101,\n",
        "        'RN50': RN50,\n",
        "        'RN50x4': RN50x4,\n",
        "        'RN50x16': RN50x16,\n",
        "    }\n",
        "    print('Settings:', setting_list)\n",
        "    with open(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\", \"w+\") as f:  # save settings\n",
        "        json.dump(setting_list, f, ensure_ascii=False, indent=4)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "TI4oAu0N4ksZ"
      },
      "source": [
        "#@title 2.3 Define the secondary diffusion model\n",
        "\n",
        "def append_dims(x, n):\n",
        "    return x[(Ellipsis, *(None,) * (n - x.ndim))]\n",
        "\n",
        "\n",
        "def expand_to_planes(x, shape):\n",
        "    return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])\n",
        "\n",
        "\n",
        "def alpha_sigma_to_t(alpha, sigma):\n",
        "    return torch.atan2(sigma, alpha) * 2 / math.pi\n",
        "\n",
        "\n",
        "def t_to_alpha_sigma(t):\n",
        "    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class DiffusionOutput:\n",
        "    v: torch.Tensor\n",
        "    pred: torch.Tensor\n",
        "    eps: torch.Tensor\n",
        "\n",
        "\n",
        "class ConvBlock(nn.Sequential):\n",
        "    def __init__(self, c_in, c_out):\n",
        "        super().__init__(\n",
        "            nn.Conv2d(c_in, c_out, 3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "        )\n",
        "\n",
        "\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, main, skip=None):\n",
        "        super().__init__()\n",
        "        self.main = nn.Sequential(*main)\n",
        "        self.skip = skip if skip else nn.Identity()\n",
        "\n",
        "    def forward(self, input):\n",
        "        return torch.cat([self.main(input), self.skip(input)], dim=1)\n",
        "\n",
        "\n",
        "class FourierFeatures(nn.Module):\n",
        "    def __init__(self, in_features, out_features, std=1.):\n",
        "        super().__init__()\n",
        "        assert out_features % 2 == 0\n",
        "        self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)\n",
        "\n",
        "    def forward(self, input):\n",
        "        f = 2 * math.pi * input @ self.weight.T\n",
        "        return torch.cat([f.cos(), f.sin()], dim=-1)\n",
        "\n",
        "\n",
        "class SecondaryDiffusionImageNet(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        c = 64  # The base channel count\n",
        "\n",
        "        self.timestep_embed = FourierFeatures(1, 16)\n",
        "\n",
        "        self.net = nn.Sequential(\n",
        "            ConvBlock(3 + 16, c),\n",
        "            ConvBlock(c, c),\n",
        "            SkipBlock([\n",
        "                nn.AvgPool2d(2),\n",
        "                ConvBlock(c, c * 2),\n",
        "                ConvBlock(c * 2, c * 2),\n",
        "                SkipBlock([\n",
        "                    nn.AvgPool2d(2),\n",
        "                    ConvBlock(c * 2, c * 4),\n",
        "                    ConvBlock(c * 4, c * 4),\n",
        "                    SkipBlock([\n",
        "                        nn.AvgPool2d(2),\n",
        "                        ConvBlock(c * 4, c * 8),\n",
        "                        ConvBlock(c * 8, c * 4),\n",
        "                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
        "                    ]),\n",
        "                    ConvBlock(c * 8, c * 4),\n",
        "                    ConvBlock(c * 4, c * 2),\n",
        "                    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
        "                ]),\n",
        "                ConvBlock(c * 4, c * 2),\n",
        "                ConvBlock(c * 2, c),\n",
        "                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
        "            ]),\n",
        "            ConvBlock(c * 2, c),\n",
        "            nn.Conv2d(c, 3, 3, padding=1),\n",
        "        )\n",
        "\n",
        "    def forward(self, input, t):\n",
        "        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
        "        v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
        "        alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
        "        pred = input * alphas - v * sigmas\n",
        "        eps = input * sigmas + v * alphas\n",
        "        return DiffusionOutput(v, pred, eps)\n",
        "\n",
        "\n",
        "class SecondaryDiffusionImageNet2(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        c = 64  # The base channel count\n",
        "        cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]\n",
        "\n",
        "        self.timestep_embed = FourierFeatures(1, 16)\n",
        "        self.down = nn.AvgPool2d(2)\n",
        "        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n",
        "\n",
        "        self.net = nn.Sequential(\n",
        "            ConvBlock(3 + 16, cs[0]),\n",
        "            ConvBlock(cs[0], cs[0]),\n",
        "            SkipBlock([\n",
        "                self.down,\n",
        "                ConvBlock(cs[0], cs[1]),\n",
        "                ConvBlock(cs[1], cs[1]),\n",
        "                SkipBlock([\n",
        "                    self.down,\n",
        "                    ConvBlock(cs[1], cs[2]),\n",
        "                    ConvBlock(cs[2], cs[2]),\n",
        "                    SkipBlock([\n",
        "                        self.down,\n",
        "                        ConvBlock(cs[2], cs[3]),\n",
        "                        ConvBlock(cs[3], cs[3]),\n",
        "                        SkipBlock([\n",
        "                            self.down,\n",
        "                            ConvBlock(cs[3], cs[4]),\n",
        "                            ConvBlock(cs[4], cs[4]),\n",
        "                            SkipBlock([\n",
        "                                self.down,\n",
        "                                ConvBlock(cs[4], cs[5]),\n",
        "                                ConvBlock(cs[5], cs[5]),\n",
        "                                ConvBlock(cs[5], cs[5]),\n",
        "                                ConvBlock(cs[5], cs[4]),\n",
        "                                self.up,\n",
        "                            ]),\n",
        "                            ConvBlock(cs[4] * 2, cs[4]),\n",
        "                            ConvBlock(cs[4], cs[3]),\n",
        "                            self.up,\n",
        "                        ]),\n",
        "                        ConvBlock(cs[3] * 2, cs[3]),\n",
        "                        ConvBlock(cs[3], cs[2]),\n",
        "                        self.up,\n",
        "                    ]),\n",
        "                    ConvBlock(cs[2] * 2, cs[2]),\n",
        "                    ConvBlock(cs[2], cs[1]),\n",
        "                    self.up,\n",
        "                ]),\n",
        "                ConvBlock(cs[1] * 2, cs[1]),\n",
        "                ConvBlock(cs[1], cs[0]),\n",
        "                self.up,\n",
        "            ]),\n",
        "            ConvBlock(cs[0] * 2, cs[0]),\n",
        "            nn.Conv2d(cs[0], 3, 3, padding=1),\n",
        "        )\n",
        "\n",
        "    def forward(self, input, t):\n",
        "        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
        "        v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
        "        alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
        "        pred = input * alphas - v * sigmas\n",
        "        eps = input * sigmas + v * alphas\n",
        "        return DiffusionOutput(v, pred, eps)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CQVtY1Ixnqx4"
      },
      "source": [
        "# 3. Diffusion and CLIP model settings"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Fpbody2NCR7w",
        "cellView": "form"
      },
      "source": [
        "diffusion_model = \"512x512_diffusion_uncond_finetune_008100\" #@param [\"256x256_diffusion_uncond\", \"512x512_diffusion_uncond_finetune_008100\"]\n",
        "\n",
        "if diffusion_model == '256x256_diffusion_uncond':\n",
        "    !wget --continue 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' -P {model_path}\n",
        "elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
        "    !wget --continue 'http://batbot.tv/ai/models/guided-diffusion/512x512_diffusion_uncond_finetune_008100.pt' -P {model_path}\n",
        "\n",
        "use_secondary_model = True #@param {type: 'boolean'}\n",
        "\n",
        "# Download the secondary diffusion model v2\n",
        "# SHA-256: 983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a\n",
        "if use_secondary_model:\n",
        "    !wget --continue 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth' -P {model_path}\n",
        "\n",
        "timestep_respacing = 'ddim250' #@param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000']\n",
        "diffusion_steps = 1000 #@param {type: 'number'}\n",
        "ViTB32 = True #@param{type:\"boolean\"}\n",
        "ViTB16 = True #@param{type:\"boolean\"}\n",
        "RN101 = False #@param{type:\"boolean\"}\n",
        "RN50 = False #@param{type:\"boolean\"}\n",
        "RN50x4 = False #@param{type:\"boolean\"}\n",
        "RN50x16 = False #@param{type:\"boolean\"}\n",
        "# ViTL = True #@param{type:\"boolean\"}\n",
        "#@markdown *RN50x16 for A100 only*\n",
        "\n",
        "model_config = model_and_diffusion_defaults()\n",
        "if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
        "    model_config.update({\n",
        "        'attention_resolutions': '32, 16, 8',\n",
        "        'class_cond': False,\n",
        "        'diffusion_steps': diffusion_steps,\n",
        "        'rescale_timesteps': True,\n",
        "        'timestep_respacing': timestep_respacing,\n",
        "        'image_size': 512,\n",
        "        'learn_sigma': True,\n",
        "        'noise_schedule': 'linear',\n",
        "        'num_channels': 256,\n",
        "        'num_head_channels': 64,\n",
        "        'num_res_blocks': 2,\n",
        "        'resblock_updown': True,\n",
        "        'use_checkpoint': True,\n",
        "        'use_fp16': True,\n",
        "        'use_scale_shift_norm': True,\n",
        "    })\n",
        "elif diffusion_model == '256x256_diffusion_uncond':\n",
        "    model_config.update({\n",
        "        'attention_resolutions': '32, 16, 8',\n",
        "        'class_cond': False,\n",
        "        'diffusion_steps': diffusion_steps,\n",
        "        'rescale_timesteps': True,\n",
        "        'timestep_respacing': timestep_respacing,\n",
        "        'image_size': 256,\n",
        "        'learn_sigma': True,\n",
        "        'noise_schedule': 'linear',\n",
        "        'num_channels': 256,\n",
        "        'num_head_channels': 64,\n",
        "        'num_res_blocks': 2,\n",
        "        'resblock_updown': True,\n",
        "        'use_checkpoint': True,\n",
        "        'use_fp16': True,\n",
        "        'use_scale_shift_norm': True,\n",
        "    })\n",
        "\n",
        "secondary_model_ver = 2\n",
        "model_default = model_config['image_size']\n",
        "\n",
        "model, diffusion = create_model_and_diffusion(**model_config)\n",
        "model.load_state_dict(torch.load(f'{model_path}/{diffusion_model}.pt', map_location='cpu'))\n",
        "model.requires_grad_(False).eval().to(device)\n",
        "for name, param in model.named_parameters():\n",
        "    if 'qkv' in name or 'norm' in name or 'proj' in name:\n",
        "        param.requires_grad_()\n",
        "if model_config['use_fp16']:\n",
        "    model.convert_to_fp16()\n",
        "\n",
        "if secondary_model_ver == 2:\n",
        "    secondary_model = SecondaryDiffusionImageNet2()\n",
        "    secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))\n",
        "    secondary_model.eval().requires_grad_(False).to(device)\n",
        "\n",
        "clip_models = []\n",
        "if ViTB32: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device))\n",
        "if ViTB16: clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device))\n",
        "if RN50: clip_models.append(clip.load('RN50', jit=False)[0].eval().requires_grad_(False).to(device))\n",
        "if RN50x4: clip_models.append(clip.load('RN50x4', jit=False)[0].eval().requires_grad_(False).to(device))\n",
        "if RN50x16: clip_models.append(clip.load('RN50x16', jit=False)[0].eval().requires_grad_(False).to(device))\n",
        "if RN101: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device))\n",
        "# if ViTL: clip_models.append(clip.load('ViT-L', jit=False)[0].eval().requires_grad_(False).to(device))\n",
        "\n",
        "normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n",
        "lpips_model = lpips.LPIPS(net='vgg').to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kjtsXaszn-bB"
      },
      "source": [
        "# 4. Settings"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U0PwzFZbLfcy",
        "cellView": "form"
      },
      "source": [
        "#@markdown ####**Basic Settings:**\n",
        "\n",
        "clip_guidance_scale = 5000 #@param{type: 'number'}\n",
        "tv_scale = 8000 #@param{type: 'number'}\n",
        "range_scale = 150 #@param{type: 'number'}\n",
        "sat_scale = 0 #@param{type: 'number'}\n",
        "cutn = 16 #@param{type: 'number'}\n",
        "cutn_batches = 2 #@param{type: 'number'}\n",
        "\n",
        "init_image = '' #@param{type: 'string'}\n",
        "init_scale = 200 #@param{type: 'number'}\n",
        "skip_timesteps = 0 #@param{type: 'number'}\n",
        "\n",
        "#@markdown Size must be a multiple of 64. Leave as `model_default` for default sizes.\n",
        "width = model_default #@param{type: 'raw'}\n",
        "height = model_default #@param{type: 'raw'}\n",
        "\n",
        "#@markdown ---\n",
        "\n",
        "#@markdown ####**Saving:**\n",
        "batch_name = 'Test' #@param{type: 'string'}\n",
        "intermediate_saves = 5 #@param{type: 'raw'}\n",
        "intermediates_in_subfolder = True #@param{type: 'boolean'}\n",
        "#@markdown Intermediate saves write a copy at your specified intervals. Format this as either a single integer or a list of specific steps.\n",
        "\n",
        "#@markdown A value of `2` will save a copy at 33% and 66%; `0` will save none.\n",
        "\n",
        "#@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets.)\n",
        "\n",
        "#@markdown ---\n",
        "\n",
        "#@markdown ####**Advanced Settings:**\n",
        "#@markdown *Perlin init will replace your init, so uncheck if using one.*\n",
        "\n",
        "perlin_init = True #@param{type: 'boolean'}\n",
        "perlin_mode = 'mixed'\n",
        "\n",
        "skip_augs = False #@param{type: 'boolean'}\n",
        "randomize_class = True #@param{type: 'boolean'}\n",
        "clip_denoised = False #@param{type: 'boolean'}\n",
        "clamp_grad = True #@param{type: 'boolean'}\n",
        "\n",
        "seed = 'random_seed' #@param{type: 'string'}\n",
        "\n",
        "fuzzy_prompt = False #@param{type: 'boolean'}\n",
        "rand_mag = 0.05 #@param{type: 'number'}\n",
        "eta = 1 #@param{type: 'number'}\n",
        "\n",
        "if not isinstance(intermediate_saves, list):\n",
        "    steps_per_checkpoint = (diffusion.num_timesteps - skip_timesteps - 1) // (intermediate_saves + 1)\n",
        "    steps_per_checkpoint = max(steps_per_checkpoint, 1)\n",
        "    print(f'Will save every {steps_per_checkpoint} steps')\n",
        "else:\n",
        "    steps_per_checkpoint = None\n",
        "\n",
        "if init_image == '':\n",
        "    init_image = None\n",
        "\n",
        "side_x = width\n",
        "side_y = height\n",
        "\n",
        "# Make folder for batch\n",
        "batchFolder = f'{outDirPath}/{batch_name}'\n",
        "createPath(batchFolder)\n",
        "\n",
        "partialFolder = f'{batchFolder}/partials'\n",
        "createPath(partialFolder)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XIwh5RvNpk4K"
      },
      "source": [
        "##Prompts"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BGBzhk3dpcGO"
      },
      "source": [
        "text_prompts = [\n",
        "    \"A lost treasure found in the depths of atlantis by greg rutkowski, trending on artstation\",\n",
        "]\n",
        "\n",
        "image_prompts = [\n",
        "    # 'mona.jpg',\n",
        "]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nf9hTc8YLoLx"
      },
      "source": [
        "# 5. Diffuse!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LHLiO56OfwgD",
        "cellView": "form"
      },
      "source": [
        "#@title Do the Run!\n",
        "\n",
        "display_rate = 2 #@param{type: 'number'}\n",
        "n_batches = 5 #@param{type: 'number'}\n",
        "batch_size = 1\n",
        "\n",
        "batchNum = len(glob(batchFolder + \"/*.txt\"))\n",
        "\n",
        "while path.isfile(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\") or path.isfile(f\"{batchFolder}/{batch_name}-{batchNum}_settings.txt\"):\n",
        "    batchNum += 1\n",
        "\n",
        "if seed == 'random_seed':\n",
        "    seed = random.randint(0, 2**32)\n",
        "else:\n",
        "    seed = int(seed)\n",
        "\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n",
        "try:\n",
        "    do_run()\n",
        "except KeyboardInterrupt:\n",
        "    pass\n",
        "finally:\n",
        "    print('seed', seed)\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}