{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Disco Diffusion v5 - Now with 3D animation\n",
"\n",
"In case of confusion, Disco is the name of this edit of the notebook. The diffusion model in use is Katherine Crowson's fine-tuned 512x512 model.\n",
"\n",
"For issues, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA) or message us on Twitter at [@somnai_dreams](https://twitter.com/somnai_dreams) or [@gandamu](https://twitter.com/gandamu_ml)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1YwMUyt9LHG1"
},
"source": [
"### Credits & Changelog \u2b07\ufe0f"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Credits\n",
"\n",
"Original notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet diffusion model or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.\n",
"\n",
"Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.\n",
"\n",
"Further improvements from Dango233 and nshepperd helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.\n",
"\n",
"Vark added code to load multiple CLIP models at once; all prompts are evaluated against every loaded model, which may greatly improve accuracy.\n",
"\n",
"The latest zoom, pan, rotation, and keyframes features were taken from Chigozie Nri's VQGAN Zoom Notebook (https://github.com/chigozienri, https://twitter.com/chigozienri).\n",
"\n",
"The advanced DangoCutn cutout method is also from Dango233.\n",
"\n",
"--\n",
"\n",
"Disco:\n",
"\n",
"Somnai (https://twitter.com/Somnai_dreams) added diffusion animation techniques, QoL improvements, and various implementations of tech and techniques, mostly listed in the changelog below.\n",
"\n",
"3D animation implementation added by Adam Letts (https://twitter.com/gandamu_ml) in collaboration with Somnai."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### License"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Licensed under the MIT License\n",
"\n",
"Copyright (c) 2021 Katherine Crowson\n",
"\n",
"Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"of this software and associated documentation files (the \"Software\"), to deal\n",
"in the Software without restriction, including without limitation the rights\n",
"to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"copies of the Software, and to permit persons to whom the Software is\n",
"furnished to do so, subject to the following conditions:\n",
"\n",
"The above copyright notice and this permission notice shall be included in\n",
"all copies or substantial portions of the Software.\n",
"\n",
"THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"THE SOFTWARE.\n",
"\n",
"--\n",
"\n",
"MIT License\n",
"\n",
"Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)\n",
"\n",
"Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"of this software and associated documentation files (the \"Software\"), to deal\n",
"in the Software without restriction, including without limitation the rights\n",
"to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"copies of the Software, and to permit persons to whom the Software is\n",
"furnished to do so, subject to the following conditions:\n",
"\n",
"The above copyright notice and this permission notice shall be included in all\n",
"copies or substantial portions of the Software.\n",
"\n",
"THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n",
"SOFTWARE.\n",
"\n",
"--\n",
"\n",
"Licensed under the MIT License\n",
"\n",
"Copyright (c) 2021 Maxwell Ingham\n",
"\n",
"Copyright (c) 2022 Adam Letts\n",
"\n",
"Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"of this software and associated documentation files (the \"Software\"), to deal\n",
"in the Software without restriction, including without limitation the rights\n",
"to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"copies of the Software, and to permit persons to whom the Software is\n",
"furnished to do so, subject to the following conditions:\n",
"\n",
"The above copyright notice and this permission notice shall be included in\n",
"all copies or substantial portions of the Software.\n",
"\n",
"THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"THE SOFTWARE."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Changelog"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"#@title <- View Changelog\n",
"skip_for_run_all = True #@param {type: 'boolean'}\n",
"\n",
"if skip_for_run_all == False:\n",
"  print(\n",
"      '''\n",
"  v1 Update: Oct 29th 2021 - Somnai\n",
"\n",
"      QoL improvements added by Somnai (@somnai_dreams), including a user-friendly UI, settings + prompt saving, and improved Google Drive folder organization.\n",
"\n",
"  v1.1 Update: Nov 13th 2021 - Somnai\n",
"\n",
"      Now includes sizing options, intermediate saves, and fixes for image prompts and perlin inits. The batch option is left unexposed since it doesn't work.\n",
"\n",
"  v2 Update: Nov 22nd 2021 - Somnai\n",
"\n",
"      Initial addition of Katherine Crowson's Secondary Model Method (https://colab.research.google.com/drive/1mpkrhOjoyzPeSWy2r7T8EYRaU7amYOOi#scrollTo=X5gODNAMEUCR)\n",
"\n",
"      Noticed settings were saving with the wrong name, so corrected it. Let me know if you preferred the old scheme.\n",
"\n",
"  v3 Update: Dec 24th 2021 - Somnai\n",
"\n",
"      Implemented Dango's advanced cutout method\n",
"\n",
"      Added SLIP models, thanks to NeuralDivergent\n",
"\n",
"      Fixed issue with NaNs resulting in black images, with massive help and testing from @Softology\n",
"\n",
"      Perlin now changes properly within batches (not sure where this perlin_regen code came from originally, but thank you)\n",
"\n",
"  v4 Update: Jan 2022 - Somnai\n",
"\n",
"      Implemented Diffusion Zooming\n",
"\n",
"      Added Chigozie keyframing\n",
"\n",
"      Made a bunch of edits to processes\n",
"\n",
"  v4.1 Update: Jan 14th 2022 - Somnai\n",
"\n",
"      Added video input mode\n",
"\n",
"      Added license that somehow went missing\n",
"\n",
"      Added improved prompt keyframing, fixed image_prompts and multiple prompts\n",
"\n",
"      Improved UI\n",
"\n",
"      Significant under-the-hood cleanup and improvement\n",
"\n",
"      Refined defaults for each mode\n",
"\n",
"      Added latent-diffusion SuperRes for sharpening\n",
"\n",
"      Added resume run mode\n",
"\n",
"  v4.9 Update: Feb 5th 2022 - gandamu / Adam Letts\n",
"\n",
"      Added 3D\n",
"\n",
"      Added brightness corrections to prevent animation from steadily going dark over time\n",
"\n",
"  v4.91 Update: Feb 19th 2022 - gandamu / Adam Letts\n",
"\n",
"      Cleaned up the 3D implementation and made associated args accessible via Colab UI elements\n",
"\n",
"  v4.92 Update: Feb 20th 2022 - gandamu / Adam Letts\n",
"\n",
"      Separated transform code\n",
"\n",
"  v5.01 Update: March 10th 2022 - gandamu / Adam Letts\n",
"\n",
"      IPython magic commands replaced by Python code\n",
"\n",
"  '''\n",
"  )\n"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "XTu6AjLyFQUq"
},
"source": [
"# Tutorial"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Diffusion settings (Defaults are heavily outdated)**\n",
"---\n",
"\n",
"This section is outdated as of v2\n",
"\n",
"Setting | Description | Default\n",
"--- | --- | ---\n",
"**Your vision:**\n",
"`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A\n",
"`image_prompts` | Think of these images more as a description of their contents. | N/A\n",
"**Image quality:**\n",
"`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000\n",
"`tv_scale` | Controls the smoothness of the final output. | 150\n",
"`range_scale` | Controls how far out of range RGB values are allowed to be. | 150\n",
"`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0\n",
"`cutn` | Controls how many crops to take from the image. | 16\n",
"`cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts. | 2\n",
"**Init settings:**\n",
"`init_image` | URL or local path | None\n",
"`init_scale` | Enhances the effect of the init image; a good value is 1000. | 0\n",
"`skip_steps` | Controls the starting point along the diffusion timesteps. | 0\n",
"`perlin_init` | Option to start with random perlin noise. | False\n",
"`perlin_mode` | ('gray', 'color') | 'mixed'\n",
"**Advanced:**\n",
"`skip_augs` | Controls whether to skip torchvision augmentations. | False\n",
"`randomize_class` | Controls whether the ImageNet class is randomly changed each iteration. | True\n",
"`clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image. | False\n",
"`clamp_grad` | Experimental: use adaptive gradient clipping in the cond_fn. | True\n",
"`seed` | Chooses a random seed and prints it at the end of the run, for reproduction. | random_seed\n",
"`fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses. | False\n",
"`rand_mag` | Controls the magnitude of the random noise. | 0.1\n",
"`eta` | DDIM hyperparameter | 0.5\n",
"\n",
"..\n",
"\n",
"**Model settings**\n",
"---\n",
"\n",
"Setting | Description | Default\n",
"--- | --- | ---\n",
"**Diffusion:**\n",
"`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100\n",
"`diffusion_steps` | | 1000\n",
"**CLIP:**\n",
"`clip_models` | Models of CLIP to load. Typically the more the better, but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4"
]
},
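{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal, hedged sketch of how the table above maps onto the plain Python variables the run cells read (illustrative values only; the exact set a given notebook version consumes may differ):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative mirror of the settings table above -- not authoritative defaults.\n",
"text_prompts = {0: [\"A beautiful painting of a lighthouse:2\", \"artstation:1\"]}  # 'text:weight' syntax\n",
"image_prompts = {}\n",
"clip_guidance_scale = 1000  # how strongly the image is pushed toward the prompt\n",
"tv_scale = 150              # smoothness of the final output\n",
"range_scale = 150           # penalty on out-of-range RGB values\n",
"sat_scale = 0               # saturation penalty\n",
"cutn_batches = 2            # accumulate CLIP gradients over this many batches of cuts\n",
"init_image = None           # URL or local path\n",
"init_scale = 0              # strengthen the init image's influence (e.g. 1000)\n",
"skip_steps = 0              # starting point along the diffusion timesteps\n",
"eta = 0.5                   # DDIM hyperparameter"
],
"outputs": [],
"execution_count": null
},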
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Set Up"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_9Eg9Kf5FlfK"
},
"source": [
"#@title 1.1 Check GPU Status\n",
"import subprocess\n",
"simple_nvidia_smi_display = False #@param {type:\"boolean\"}\n",
"if simple_nvidia_smi_display:\n",
"  #!nvidia-smi\n",
"  nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(nvidiasmi_output)\n",
"else:\n",
"  #!nvidia-smi -i 0 -e 0\n",
"  nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(nvidiasmi_output)\n",
"  nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(nvidiasmi_ecc_note)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"#@title 1.2 Prepare Folders\n",
"import subprocess\n",
"import sys\n",
"import ipykernel\n",
"\n",
"def gitclone(url):\n",
"  res = subprocess.run(['git', 'clone', url], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(res)\n",
"\n",
"def pipi(modulestr):\n",
"  res = subprocess.run(['pip', 'install', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(res)\n",
"\n",
"def pipie(modulestr):\n",
"  # editable install ('pip install -e')\n",
"  res = subprocess.run(['pip', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(res)\n",
"\n",
"def wget(url, outputdir):\n",
"  res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(res)\n",
"\n",
"try:\n",
"  from google.colab import drive\n",
"  print(\"Google Colab detected. Using Google Drive.\")\n",
"  is_colab = True\n",
"  #@markdown If you connect your Google Drive, you can save the final image of each run on your drive.\n",
"  google_drive = True #@param {type:\"boolean\"}\n",
"  #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
"  save_models_to_google_drive = True #@param {type:\"boolean\"}\n",
"except:\n",
"  is_colab = False\n",
"  google_drive = False\n",
"  save_models_to_google_drive = False\n",
"  print(\"Google Colab not detected.\")\n",
"\n",
"if is_colab:\n",
"  if google_drive is True:\n",
"    drive.mount('/content/drive')\n",
"    root_path = '/content/drive/MyDrive/AI/Disco_Diffusion'\n",
"  else:\n",
"    root_path = '/content'\n",
"else:\n",
"  root_path = '.'\n",
"\n",
"import os\n",
"from os import path\n",
"# Simple path creation, taken with modifications from Datamosh's Batch VQGAN+CLIP notebook\n",
"def createPath(filepath):\n",
"  if not path.exists(filepath):\n",
"    os.makedirs(filepath)\n",
"    print(f'Made {filepath}')\n",
"  else:\n",
"    print(f'filepath {filepath} exists.')\n",
"\n",
"initDirPath = f'{root_path}/init_images'\n",
"createPath(initDirPath)\n",
"outDirPath = f'{root_path}/images_out'\n",
"createPath(outDirPath)\n",
"\n",
"if is_colab:\n",
"  if (google_drive and not save_models_to_google_drive) or not google_drive:\n",
"    model_path = '/content/model'\n",
"    createPath(model_path)\n",
"  if google_drive and save_models_to_google_drive:\n",
"    model_path = f'{root_path}/model'\n",
"    createPath(model_path)\n",
"else:\n",
"  model_path = f'{root_path}/model'\n",
"  createPath(model_path)\n",
"\n",
"# libraries = f'{root_path}/libraries'\n",
"# createPath(libraries)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"#@title ### 1.3 Install and import dependencies\n",
"\n",
"from os.path import exists as path_exists\n",
"import pathlib, shutil\n",
"\n",
"if not is_colab:\n",
"  # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.\n",
"  os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\n",
"\n",
"PROJECT_DIR = os.path.abspath(os.getcwd())\n",
"USE_ADABINS = True\n",
"\n",
"if is_colab:\n",
"  if google_drive is not True:\n",
"    root_path = f'/content'\n",
"    model_path = '/content/models'\n",
"else:\n",
"  root_path = f'.'\n",
"  model_path = f'{root_path}/model'\n",
"\n",
"model_256_downloaded = False\n",
"model_512_downloaded = False\n",
"model_secondary_downloaded = False\n",
"\n",
"if is_colab:\n",
"  gitclone(\"https://github.com/openai/CLIP\")\n",
"  #gitclone(\"https://github.com/facebookresearch/SLIP.git\")\n",
"  gitclone(\"https://github.com/crowsonkb/guided-diffusion\")\n",
"  gitclone(\"https://github.com/assafshocher/ResizeRight.git\")\n",
"  pipie(\"./CLIP\")\n",
"  pipie(\"./guided-diffusion\")\n",
"  multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  print(multipip_res)\n",
"  subprocess.run(['apt', 'install', 'imagemagick'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"  gitclone(\"https://github.com/isl-org/MiDaS.git\")\n",
"  gitclone(\"https://github.com/alembics/disco-diffusion.git\")\n",
"  pipi(\"pytorch-lightning\")\n",
"  pipi(\"omegaconf\")\n",
"  pipi(\"einops\")\n",
"  # Rename a file to avoid a name conflict.\n",
"  try:\n",
"    os.rename(\"MiDaS/utils.py\", \"MiDaS/midas_utils.py\")\n",
"    shutil.copyfile(\"disco-diffusion/disco_xform_utils.py\", \"disco_xform_utils.py\")\n",
"  except:\n",
"    pass\n",
"\n",
"if not path_exists(f'{model_path}'):\n",
"  pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)\n",
"if not path_exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n",
"  wget(\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\", model_path)\n",
"\n",
"import sys\n",
"import torch\n",
"\n",
"#Install pytorch3d\n",
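"# The prebuilt pytorch3d wheels are indexed by a Python/CUDA/torch tag\n",
"# (e.g. 'py38_cu113_pyt1100'), assembled below from the running environment.\n",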
"if is_colab:\n", |
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", |
|
" version_str=\"\".join([\n", |
|
" f\"py3{sys.version_info.minor}_cu\",\n", |
|
" torch.version.cuda.replace(\".\",\"\"),\n", |
|
" f\"_pyt{pyt_version_str}\"\n", |
|
" ])\n", |
|
" multipip_res = subprocess.run(['pip', 'install', 'fvcore', 'iopath'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", |
|
" print(multipip_res)\n", |
|
" subprocess.run(['pip', 'install', '--no-index', '--no-cache-dir', 'pytorch3d', '-f', f'https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", |
|
"\n", |
|
"# sys.path.append('./SLIP')\n", |
|
"sys.path.append('./ResizeRight')\n", |
|
"sys.path.append('./MiDaS')\n", |
|
"from dataclasses import dataclass\n", |
|
"from functools import partial\n", |
|
"import cv2\n", |
|
"import pandas as pd\n", |
|
"import gc\n", |
|
"import io\n", |
|
"import math\n", |
|
"import timm\n", |
|
"from IPython import display\n", |
|
"import lpips\n", |
|
"from PIL import Image, ImageOps\n", |
|
"import requests\n", |
|
"from glob import glob\n", |
|
"import json\n", |
|
"from types import SimpleNamespace\n", |
|
"from torch import nn\n", |
|
"from torch.nn import functional as F\n", |
|
"import torchvision.transforms as T\n", |
|
"import torchvision.transforms.functional as TF\n", |
|
"from tqdm.notebook import tqdm\n", |
|
"sys.path.append('./CLIP')\n", |
|
"sys.path.append('./guided-diffusion')\n", |
|
"import clip\n", |
|
"from resize_right import resize\n", |
|
"# from models import SLIP_VITB16, SLIP, SLIP_VITL16\n", |
|
"from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n", |
|
"from datetime import datetime\n", |
|
"import numpy as np\n", |
|
"import matplotlib.pyplot as plt\n", |
|
"import random\n", |
|
"from ipywidgets import Output\n", |
|
"import hashlib\n", |
|
"\n", |
|
"#SuperRes\n", |
|
"if is_colab:\n", |
|
" gitclone(\"https://github.com/CompVis/latent-diffusion.git\")\n", |
|
" gitclone(\"https://github.com/CompVis/taming-transformers\")\n", |
|
" pipie(\"./taming-transformers\")\n", |
|
" pipi(\"ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\")\n", |
|
"\n", |
|
"#SuperRes\n", |
|
"import ipywidgets as widgets\n", |
|
"import os\n", |
|
"sys.path.append(\".\")\n", |
|
"sys.path.append('./taming-transformers')\n", |
|
"from taming.models import vqgan # checking correct import from taming\n", |
|
"from torchvision.datasets.utils import download_url\n", |
|
"\n", |
|
"if is_colab:\n", |
|
" os.chdir('/content/latent-diffusion')\n", |
|
"else:\n", |
|
" #os.chdir('latent-diffusion')\n", |
|
" sys.path.append('latent-diffusion')\n", |
|
"from functools import partial\n", |
|
"from ldm.util import instantiate_from_config\n", |
|
"from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\n", |
|
"# from ldm.models.diffusion.ddim import DDIMSampler\n", |
|
"from ldm.util import ismap\n", |
|
"if is_colab:\n", |
|
" os.chdir('/content')\n", |
|
" from google.colab import files\n", |
|
"else:\n", |
|
" os.chdir(f'{PROJECT_DIR}')\n", |
|
"from IPython.display import Image as ipyimg\n", |
|
"from numpy import asarray\n", |
|
"from einops import rearrange, repeat\n", |
|
"import torch, torchvision\n", |
|
"import time\n", |
|
"from omegaconf import OmegaConf\n", |
|
"import warnings\n", |
|
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n", |
|
"\n", |
|
"# AdaBins stuff\n", |
|
"if USE_ADABINS:\n", |
|
" if is_colab:\n", |
|
" gitclone(\"https://github.com/shariqfarooq123/AdaBins.git\")\n", |
|
" if not path_exists(f'{model_path}/AdaBins_nyu.pt'):\n", |
|
" wget(\"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt\", model_path)\n", |
|
" pathlib.Path(\"pretrained\").mkdir(parents=True, exist_ok=True)\n", |
|
" shutil.copyfile(f\"{model_path}/AdaBins_nyu.pt\", \"pretrained/AdaBins_nyu.pt\")\n", |
|
" sys.path.append('./AdaBins')\n", |
|
" from infer import InferenceHelper\n", |
|
" MAX_ADABINS_AREA = 500000\n", |
|
"\n", |
|
"import torch\n", |
|
"DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", |
|
"print('Using device:', DEVICE)\n", |
|
"device = DEVICE # At least one of the modules expects this name..\n", |
|
"\n", |
|
"if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad\n", |
|
" print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n", |
|
" torch.backends.cudnn.enabled = False" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{
"cell_type": "code",
"metadata": {},
"source": [
"#@title ### 1.4 Define Midas functions\n",
"\n",
"from midas.dpt_depth import DPTDepthModel\n",
"from midas.midas_net import MidasNet\n",
"from midas.midas_net_custom import MidasNet_small\n",
"from midas.transforms import Resize, NormalizeImage, PrepareForNet\n",
"\n",
"# Initialize MiDaS depth model.\n",
"# It remains resident in VRAM and likely takes around 2GB VRAM.\n",
"# You could instead initialize it for each frame (and free it after each frame) to save VRAM, but initializing it is slow.\n",
"default_models = {\n",
"    \"midas_v21_small\": f\"{model_path}/midas_v21_small-70d6b9c8.pt\",\n",
"    \"midas_v21\": f\"{model_path}/midas_v21-f6b98070.pt\",\n",
"    \"dpt_large\": f\"{model_path}/dpt_large-midas-2f21e586.pt\",\n",
"    \"dpt_hybrid\": f\"{model_path}/dpt_hybrid-midas-501f0c75.pt\",\n",
"    \"dpt_hybrid_nyu\": f\"{model_path}/dpt_hybrid_nyu-2ce69ec7.pt\",\n",
"}\n",
"\n",
"\n",
"def init_midas_depth_model(midas_model_type=\"dpt_large\", optimize=True):\n",
"    midas_model = None\n",
"    net_w = None\n",
"    net_h = None\n",
"    resize_mode = None\n",
"    normalization = None\n",
"\n",
"    print(f\"Initializing MiDaS '{midas_model_type}' depth model...\")\n",
"    # load network\n",
"    midas_model_path = default_models[midas_model_type]\n",
"\n",
"    if midas_model_type == \"dpt_large\": # DPT-Large\n",
"        midas_model = DPTDepthModel(\n",
"            path=midas_model_path,\n",
"            backbone=\"vitl16_384\",\n",
"            non_negative=True,\n",
"        )\n",
"        net_w, net_h = 384, 384\n",
"        resize_mode = \"minimal\"\n",
"        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
"    elif midas_model_type == \"dpt_hybrid\": # DPT-Hybrid\n",
"        midas_model = DPTDepthModel(\n",
"            path=midas_model_path,\n",
"            backbone=\"vitb_rn50_384\",\n",
"            non_negative=True,\n",
"        )\n",
"        net_w, net_h = 384, 384\n",
"        resize_mode = \"minimal\"\n",
"        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
"    elif midas_model_type == \"dpt_hybrid_nyu\": # DPT-Hybrid-NYU\n",
"        midas_model = DPTDepthModel(\n",
"            path=midas_model_path,\n",
"            backbone=\"vitb_rn50_384\",\n",
"            non_negative=True,\n",
"        )\n",
"        net_w, net_h = 384, 384\n",
"        resize_mode = \"minimal\"\n",
"        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
"    elif midas_model_type == \"midas_v21\":\n",
"        midas_model = MidasNet(midas_model_path, non_negative=True)\n",
"        net_w, net_h = 384, 384\n",
"        resize_mode = \"upper_bound\"\n",
"        normalization = NormalizeImage(\n",
"            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
"        )\n",
"    elif midas_model_type == \"midas_v21_small\":\n",
"        midas_model = MidasNet_small(midas_model_path, features=64, backbone=\"efficientnet_lite3\", exportable=True, non_negative=True, blocks={'expand': True})\n",
"        net_w, net_h = 256, 256\n",
"        resize_mode = \"upper_bound\"\n",
"        normalization = NormalizeImage(\n",
"            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
"        )\n",
"    else:\n",
"        print(f\"midas_model_type '{midas_model_type}' not implemented\")\n",
"        assert False\n",
"\n",
"    midas_transform = T.Compose(\n",
"        [\n",
"            Resize(\n",
"                net_w,\n",
"                net_h,\n",
"                resize_target=None,\n",
"                keep_aspect_ratio=True,\n",
"                ensure_multiple_of=32,\n",
"                resize_method=resize_mode,\n",
"                image_interpolation_method=cv2.INTER_CUBIC,\n",
"            ),\n",
"            normalization,\n",
"            PrepareForNet(),\n",
"        ]\n",
"    )\n",
"\n",
"    midas_model.eval()\n",
"\n",
"    if optimize == True:\n",
"        if DEVICE == torch.device(\"cuda\"):\n",
"            midas_model = midas_model.to(memory_format=torch.channels_last)\n",
"            midas_model = midas_model.half()\n",
"\n",
"    midas_model.to(DEVICE)\n",
"\n",
"    print(f\"MiDaS '{midas_model_type}' depth model initialized.\")\n",
"    return midas_model, midas_transform, net_w, net_h, resize_mode, normalization"
],
"outputs": [],
"execution_count": null
},
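{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick usage sketch of the helper above, kept commented out so \"run all\" is unaffected (`do_run()` below performs the real initialization when 3D animation mode is active):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical smoke test for init_midas_depth_model -- not part of the normal pipeline.\n",
"# midas_model, midas_transform, net_w, net_h, resize_mode, normalization = init_midas_depth_model(\"dpt_large\")\n",
"# print(net_w, net_h, resize_mode)"
],
"outputs": [],
"execution_count": null
},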
{
"cell_type": "code",
"metadata": {},
"source": [
"#@title 1.5 Define necessary functions\n",
"\n",
"# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n",
"\n",
"import pytorch3d.transforms as p3dT\n",
"import disco_xform_utils as dxf\n",
"\n",
"def interp(t):\n",
"    return 3 * t**2 - 2 * t**3\n",
"\n",
"def perlin(width, height, scale=10, device=None):\n",
"    gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n",
"    xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n",
"    ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n",
"    wx = 1 - interp(xs)\n",
"    wy = 1 - interp(ys)\n",
"    dots = 0\n",
"    dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n",
"    dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n",
"    dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n",
"    dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n",
"    return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n",
"\n",
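"# perlin_ms() stacks octaves of the single-octave perlin() above: each octave halves\n",
"# the cell scale and doubles the lattice resolution, weighted by the entries of\n",
"# `octaves`, giving multi-scale (fractal) noise per channel.\n",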
"def perlin_ms(octaves, width, height, grayscale, device=device):\n", |
|
" out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n", |
|
" # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n", |
|
" for i in range(1 if grayscale else 3):\n", |
|
" scale = 2 ** len(octaves)\n", |
|
" oct_width = width\n", |
|
" oct_height = height\n", |
|
" for oct in octaves:\n", |
|
" p = perlin(oct_width, oct_height, scale, device)\n", |
|
" out_array[i] += p * oct\n", |
|
" scale //= 2\n", |
|
" oct_width *= 2\n", |
|
" oct_height *= 2\n", |
|
" return torch.cat(out_array)\n", |
|
"\n", |
|
"def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n", |
|
" out = perlin_ms(octaves, width, height, grayscale)\n", |
|
" if grayscale:\n", |
|
" out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))\n", |
|
" out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n", |
|
" else:\n", |
|
" out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n", |
|
" out = TF.resize(size=(side_y, side_x), img=out)\n", |
|
" out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n", |
|
"\n", |
|
" out = ImageOps.autocontrast(out)\n", |
|
" return out\n", |
|
"\n", |
|
"def regen_perlin():\n", |
|
" if perlin_mode == 'color':\n", |
|
" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", |
|
" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n", |
|
" elif perlin_mode == 'gray':\n", |
|
" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n", |
|
" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", |
|
" else:\n", |
|
" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", |
|
" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", |
|
"\n", |
|
" init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n", |
|
" del init2\n", |
|
" return init.expand(batch_size, -1, -1, -1)\n", |
|
"\n", |
|
"def fetch(url_or_path):\n", |
|
" if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n", |
|
" r = requests.get(url_or_path)\n", |
|
" r.raise_for_status()\n", |
|
" fd = io.BytesIO()\n", |
|
" fd.write(r.content)\n", |
|
" fd.seek(0)\n", |
|
" return fd\n", |
|
" return open(url_or_path, 'rb')\n", |
|
"\n", |
|
"def read_image_workaround(path):\n", |
|
" \"\"\"OpenCV reads images as BGR, Pillow saves them as RGB. Work around\n", |
|
" this incompatibility to avoid colour inversions.\"\"\"\n", |
|
" im_tmp = cv2.imread(path)\n", |
|
" return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)\n", |
|
"\n", |
|
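"# parse_prompt() splits a 'text:weight' prompt string into (text, weight). URL prompts\n",
"# are split on the last colons so the '://' in the scheme survives; weight defaults to 1.\n",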
"def parse_prompt(prompt):\n", |
|
" if prompt.startswith('http://') or prompt.startswith('https://'):\n", |
|
" vals = prompt.rsplit(':', 2)\n", |
|
" vals = [vals[0] + ':' + vals[1], *vals[2:]]\n", |
|
" else:\n", |
|
" vals = prompt.rsplit(':', 1)\n", |
|
" vals = vals + ['', '1'][len(vals):]\n", |
|
" return vals[0], float(vals[1])\n", |
|
"\n", |
|
"def sinc(x):\n", |
|
" return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n", |
|
"\n", |
|
"def lanczos(x, a):\n", |
|
" cond = torch.logical_and(-a < x, x < a)\n", |
|
" out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n", |
|
" return out / out.sum()\n", |
|
"\n", |
|
"def ramp(ratio, width):\n", |
|
" n = math.ceil(width / ratio + 1)\n", |
|
" out = torch.empty([n])\n", |
|
" cur = 0\n", |
|
" for i in range(out.shape[0]):\n", |
|
" out[i] = cur\n", |
|
" cur += ratio\n", |
|
" return torch.cat([-out[1:].flip([0]), out])[1:-1]\n", |
|
"\n", |
|
"def resample(input, size, align_corners=True):\n", |
|
" n, c, h, w = input.shape\n", |
|
" dh, dw = size\n", |
|
"\n", |
|
" input = input.reshape([n * c, 1, h, w])\n", |
|
"\n", |
|
" if dh < h:\n", |
|
" kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n", |
|
" pad_h = (kernel_h.shape[0] - 1) // 2\n", |
|
" input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n", |
|
" input = F.conv2d(input, kernel_h[None, None, :, None])\n", |
|
"\n", |
|
" if dw < w:\n", |
|
" kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n", |
|
" pad_w = (kernel_w.shape[0] - 1) // 2\n", |
|
" input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n", |
|
" input = F.conv2d(input, kernel_w[None, None, None, :])\n", |
|
"\n", |
|
" input = input.reshape([n, c, h, w])\n", |
|
" return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n", |
|
"\n", |
|
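"# MakeCutouts renders `cutn` random square crops of the padded image (optionally\n",
"# augmented), resized to CLIP's input resolution, so CLIP can grade many views of\n",
"# the image at each diffusion step.\n",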
"class MakeCutouts(nn.Module):\n", |
|
" def __init__(self, cut_size, cutn, skip_augs=False):\n", |
|
" super().__init__()\n", |
|
" self.cut_size = cut_size\n", |
|
" self.cutn = cutn\n", |
|
" self.skip_augs = skip_augs\n", |
|
" self.augs = T.Compose([\n", |
|
" T.RandomHorizontalFlip(p=0.5),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomPerspective(distortion_scale=0.4, p=0.7),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomGrayscale(p=0.15),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n", |
|
" ])\n", |
|
"\n", |
|
" def forward(self, input):\n", |
|
" input = T.Pad(input.shape[2]//4, fill=0)(input)\n", |
|
" sideY, sideX = input.shape[2:4]\n", |
|
" max_size = min(sideX, sideY)\n", |
|
"\n", |
|
" cutouts = []\n", |
|
" for ch in range(self.cutn):\n", |
|
" if ch > self.cutn - self.cutn//4:\n", |
|
" cutout = input.clone()\n", |
|
" else:\n", |
|
" size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n", |
|
" offsetx = torch.randint(0, abs(sideX - size + 1), ())\n", |
|
" offsety = torch.randint(0, abs(sideY - size + 1), ())\n", |
|
" cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", |
|
"\n", |
|
" if not self.skip_augs:\n", |
|
" cutout = self.augs(cutout)\n", |
|
" cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n", |
|
" del cutout\n", |
|
"\n", |
|
" cutouts = torch.cat(cutouts, dim=0)\n", |
|
" return cutouts\n", |
|
"\n", |
|
"cutout_debug = False\n", |
|
"padargs = {}\n", |
|
"\n", |
|
"class MakeCutoutsDango(nn.Module):\n", |
|
" def __init__(self, cut_size,\n", |
|
" Overview=4, \n", |
|
" InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2\n", |
|
" ):\n", |
|
" super().__init__()\n", |
|
" self.cut_size = cut_size\n", |
|
" self.Overview = Overview\n", |
|
" self.InnerCrop = InnerCrop\n", |
|
" self.IC_Size_Pow = IC_Size_Pow\n", |
|
" self.IC_Grey_P = IC_Grey_P\n", |
|
" if args.animation_mode == 'None':\n", |
|
" self.augs = T.Compose([\n", |
|
" T.RandomHorizontalFlip(p=0.5),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomGrayscale(p=0.1),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n", |
|
" ])\n", |
|
" elif args.animation_mode == 'Video Input':\n", |
|
" self.augs = T.Compose([\n", |
|
" T.RandomHorizontalFlip(p=0.5),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomPerspective(distortion_scale=0.4, p=0.7),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomGrayscale(p=0.15),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n", |
|
" ])\n", |
|
" elif args.animation_mode == '2D' or args.animation_mode == '3D':\n", |
|
" self.augs = T.Compose([\n", |
|
" T.RandomHorizontalFlip(p=0.4),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.RandomGrayscale(p=0.1),\n", |
|
" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", |
|
" T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),\n", |
|
" ])\n", |
|
" \n", |
|
"\n", |
|
" def forward(self, input):\n", |
|
" cutouts = []\n", |
|
" gray = T.Grayscale(3)\n", |
|
" sideY, sideX = input.shape[2:4]\n", |
|
" max_size = min(sideX, sideY)\n", |
|
" min_size = min(sideX, sideY, self.cut_size)\n", |
|
" l_size = max(sideX, sideY)\n", |
|
" output_shape = [1,3,self.cut_size,self.cut_size] \n", |
|
" output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]\n", |
|
" pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)\n", |
|
" cutout = resize(pad_input, out_shape=output_shape)\n", |
|
"\n", |
|
" if self.Overview>0:\n", |
|
" if self.Overview<=4:\n", |
|
" if self.Overview>=1:\n", |
|
" cutouts.append(cutout)\n", |
|
" if self.Overview>=2:\n", |
|
" cutouts.append(gray(cutout))\n", |
|
" if self.Overview>=3:\n", |
|
" cutouts.append(TF.hflip(cutout))\n", |
|
" if self.Overview==4:\n", |
|
" cutouts.append(gray(TF.hflip(cutout)))\n", |
|
" else:\n", |
|
" cutout = resize(pad_input, out_shape=output_shape)\n", |
|
" for _ in range(self.Overview):\n", |
|
" cutouts.append(cutout)\n", |
|
"\n", |
|
" if cutout_debug:\n", |
|
" if is_colab:\n", |
|
" TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"/content/cutout_overview0.jpg\",quality=99)\n", |
|
" else:\n", |
|
" TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"cutout_overview0.jpg\",quality=99)\n", |
|
"\n", |
|
" \n", |
|
" if self.InnerCrop >0:\n", |
|
" for i in range(self.InnerCrop):\n", |
|
" size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)\n", |
|
" offsetx = torch.randint(0, sideX - size + 1, ())\n", |
|
" offsety = torch.randint(0, sideY - size + 1, ())\n", |
|
" cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", |
|
" if i <= int(self.IC_Grey_P * self.InnerCrop):\n", |
|
" cutout = gray(cutout)\n", |
|
" cutout = resize(cutout, out_shape=output_shape)\n", |
|
" cutouts.append(cutout)\n", |
|
" if cutout_debug:\n", |
|
" if is_colab:\n", |
|
" TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"/content/cutout_InnerCrop.jpg\",quality=99)\n", |
|
" else:\n", |
|
" TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"cutout_InnerCrop.jpg\",quality=99)\n", |
|
" cutouts = torch.cat(cutouts)\n", |
|
" if skip_augs is not True: cutouts=self.augs(cutouts)\n", |
|
" return cutouts\n", |
|
"\n", |
|
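"# spherical_dist_loss: proportional to the squared great-circle distance between\n",
"# L2-normalized embeddings; a smoother alternative to plain cosine distance.\n",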
"def spherical_dist_loss(x, y):\n", |
|
" x = F.normalize(x, dim=-1)\n", |
|
" y = F.normalize(y, dim=-1)\n", |
|
" return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) \n", |
|
"\n", |
|
"def tv_loss(input):\n", |
|
" \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n", |
|
" input = F.pad(input, (0, 1, 0, 1), 'replicate')\n", |
|
" x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n", |
|
" y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n", |
|
" return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n", |
|
"\n", |
|
"\n", |
|
"def range_loss(input):\n", |
|
" return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n", |
|
"\n", |
|
"stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n", |
|
"\n", |
|
"def do_run():\n", |
|
" seed = args.seed\n", |
|
" print(range(args.start_frame, args.max_frames))\n", |
|
"\n", |
|
" if (args.animation_mode == \"3D\") and (args.midas_weight > 0.0):\n", |
|
" midas_model, midas_transform, midas_net_w, midas_net_h, midas_resize_mode, midas_normalization = init_midas_depth_model(args.midas_depth_model)\n", |
|
" for frame_num in range(args.start_frame, args.max_frames):\n", |
|
" if stop_on_next_loop:\n", |
|
" break\n", |
|
" \n", |
|
" display.clear_output(wait=True)\n", |
|
"\n", |
|
" # Print Frame progress if animation mode is on\n", |
|
" if args.animation_mode != \"None\":\n", |
|
" batchBar = tqdm(range(args.max_frames), desc =\"Frames\")\n", |
|
" batchBar.n = frame_num\n", |
|
" batchBar.refresh()\n", |
|
"\n", |
|
" \n", |
|
" # Inits if not video frames\n", |
|
" if args.animation_mode != \"Video Input\":\n", |
|
" if args.init_image == '':\n", |
|
" init_image = None\n", |
|
" else:\n", |
|
" init_image = args.init_image\n", |
|
" init_scale = args.init_scale\n", |
|
" skip_steps = args.skip_steps\n", |
|
"\n", |
|
" if args.animation_mode == \"2D\":\n", |
|
" if args.key_frames:\n", |
|
" angle = args.angle_series[frame_num]\n", |
|
" zoom = args.zoom_series[frame_num]\n", |
|
" translation_x = args.translation_x_series[frame_num]\n", |
|
" translation_y = args.translation_y_series[frame_num]\n", |
|
" print(\n", |
|
" f'angle: {angle}',\n", |
|
" f'zoom: {zoom}',\n", |
|
" f'translation_x: {translation_x}',\n", |
|
" f'translation_y: {translation_y}',\n", |
|
" )\n", |
|
" \n", |
|
" if frame_num > 0:\n", |
|
" seed = seed + 1 \n", |
|
" if resume_run and frame_num == start_frame:\n", |
|
" img_0 = cv2.imread(batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\")\n", |
|
" else:\n", |
|
" img_0 = cv2.imread('prevFrame.png')\n", |
|
" center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)\n", |
|
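"          # Compose the per-frame 2D motion as a single homography: translate first,\n",
"          # then rotate/zoom about the image center, and warp with wraparound borders.\n",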
" trans_mat = np.float32(\n", |
|
" [[1, 0, translation_x],\n", |
|
" [0, 1, translation_y]]\n", |
|
" )\n", |
|
" rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )\n", |
|
" trans_mat = np.vstack([trans_mat, [0,0,1]])\n", |
|
" rot_mat = np.vstack([rot_mat, [0,0,1]])\n", |
|
" transformation_matrix = np.matmul(rot_mat, trans_mat)\n", |
|
" img_0 = cv2.warpPerspective(\n", |
|
" img_0,\n", |
|
" transformation_matrix,\n", |
|
" (img_0.shape[1], img_0.shape[0]),\n", |
|
" borderMode=cv2.BORDER_WRAP\n", |
|
" )\n", |
|
"\n", |
|
" cv2.imwrite('prevFrameScaled.png', img_0)\n", |
|
" init_image = 'prevFrameScaled.png'\n", |
|
" init_scale = args.frames_scale\n", |
|
" skip_steps = args.calc_frames_skip_steps\n", |
|
"\n", |
|
" if args.animation_mode == \"3D\":\n", |
|
" if args.key_frames:\n", |
|
" angle = args.angle_series[frame_num]\n", |
|
" #zoom = args.zoom_series[frame_num]\n", |
|
" translation_x = args.translation_x_series[frame_num]\n", |
|
" translation_y = args.translation_y_series[frame_num]\n", |
|
" translation_z = args.translation_z_series[frame_num]\n", |
|
" rotation_3d_x = args.rotation_3d_x_series[frame_num]\n", |
|
" rotation_3d_y = args.rotation_3d_y_series[frame_num]\n", |
|
" rotation_3d_z = args.rotation_3d_z_series[frame_num]\n", |
|
" print(\n", |
|
" f'angle: {angle}',\n", |
|
" #f'zoom: {zoom}',\n", |
|
" f'translation_x: {translation_x}',\n", |
|
" f'translation_y: {translation_y}',\n", |
|
" f'translation_z: {translation_z}',\n", |
|
" f'rotation_3d_x: {rotation_3d_x}',\n", |
|
" f'rotation_3d_y: {rotation_3d_y}',\n", |
|
" f'rotation_3d_z: {rotation_3d_z}',\n", |
|
" )\n", |
|
"\n", |
|
" if frame_num > 0:\n", |
|
" seed = seed + 1 \n", |
|
" if resume_run and frame_num == start_frame:\n", |
|
" img_filepath = batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\"\n", |
|
" else:\n", |
|
" img_filepath = '/content/prevFrame.png' if is_colab else 'prevFrame.png'\n", |
|
" trans_scale = 1.0/200.0\n", |
|
" translate_xyz = [-translation_x*trans_scale, translation_y*trans_scale, -translation_z*trans_scale]\n", |
|
" rotate_xyz = [rotation_3d_x, rotation_3d_y, rotation_3d_z]\n", |
|
" print('translation:',translate_xyz)\n", |
|
" print('rotation:',rotate_xyz)\n", |
|
" rot_mat = p3dT.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), \"XYZ\").unsqueeze(0)\n", |
|
" print(\"rot_mat: \" + str(rot_mat))\n", |
|
" next_step_pil = dxf.transform_image_3d(img_filepath, midas_model, midas_transform, DEVICE,\n", |
|
" rot_mat, translate_xyz, args.near_plane, args.far_plane,\n", |
|
" args.fov, padding_mode=args.padding_mode,\n", |
|
" sampling_mode=args.sampling_mode, midas_weight=args.midas_weight)\n", |
|
" next_step_pil.save('prevFrameScaled.png')\n", |
|
" init_image = 'prevFrameScaled.png'\n", |
|
" init_scale = args.frames_scale\n", |
|
" skip_steps = args.calc_frames_skip_steps\n", |
|
"\n", |
|
" if args.animation_mode == \"Video Input\":\n", |
|
" seed = seed + 1 \n", |
|
" init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'\n", |
|
" init_scale = args.frames_scale\n", |
|
" skip_steps = args.calc_frames_skip_steps\n", |
|
"\n", |
|
" loss_values = []\n", |
|
" \n", |
|
" if seed is not None:\n", |
|
" np.random.seed(seed)\n", |
|
" random.seed(seed)\n", |
|
" torch.manual_seed(seed)\n", |
|
" torch.cuda.manual_seed_all(seed)\n", |
|
" torch.backends.cudnn.deterministic = True\n", |
|
" \n", |
|
" target_embeds, weights = [], []\n", |
|
" \n", |
|
" if args.prompts_series is not None and frame_num >= len(args.prompts_series):\n", |
|
" frame_prompt = args.prompts_series[-1]\n", |
|
" elif args.prompts_series is not None:\n", |
|
" frame_prompt = args.prompts_series[frame_num]\n", |
|
" else:\n", |
|
" frame_prompt = []\n", |
|
" \n", |
|
" print(args.image_prompts_series)\n", |
|
" if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series):\n", |
|
" image_prompt = args.image_prompts_series[-1]\n", |
|
" elif args.image_prompts_series is not None:\n", |
|
" image_prompt = args.image_prompts_series[frame_num]\n", |
|
" else:\n", |
|
" image_prompt = []\n", |
|
"\n", |
|
" print(f'Frame Prompt: {frame_prompt}')\n", |
|
"\n", |
|
" model_stats = []\n", |
|
" for clip_model in clip_models:\n", |
|
" cutn = 16\n", |
|
" model_stat = {\"clip_model\":None,\"target_embeds\":[],\"make_cutouts\":None,\"weights\":[]}\n", |
|
" model_stat[\"clip_model\"] = clip_model\n", |
|
" \n", |
|
" \n", |
|
" for prompt in frame_prompt:\n", |
|
" txt, weight = parse_prompt(prompt)\n", |
|
" txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()\n", |
|
" \n", |
|
" if args.fuzzy_prompt:\n", |
|
" for i in range(25):\n", |
|
" model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))\n", |
|
" model_stat[\"weights\"].append(weight)\n", |
|
" else:\n", |
|
" model_stat[\"target_embeds\"].append(txt)\n", |
|
" model_stat[\"weights\"].append(weight)\n", |
|
" \n", |
|
" if image_prompt:\n", |
|
" model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs) \n", |
|
" for prompt in image_prompt:\n", |
|
" path, weight = parse_prompt(prompt)\n", |
|
" img = Image.open(fetch(path)).convert('RGB')\n", |
|
" img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)\n", |
|
" batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n", |
|
" embed = clip_model.encode_image(normalize(batch)).float()\n", |
|
" if fuzzy_prompt:\n", |
|
" for i in range(25):\n", |
|
" model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n", |
|
" weights.extend([weight / cutn] * cutn)\n", |
|
" else:\n", |
|
" model_stat[\"target_embeds\"].append(embed)\n", |
|
" model_stat[\"weights\"].extend([weight / cutn] * cutn)\n", |
|
" \n", |
|
" model_stat[\"target_embeds\"] = torch.cat(model_stat[\"target_embeds\"])\n", |
|
" model_stat[\"weights\"] = torch.tensor(model_stat[\"weights\"], device=device)\n", |
|
" if model_stat[\"weights\"].sum().abs() < 1e-3:\n", |
|
" raise RuntimeError('The weights must not sum to 0.')\n", |
|
" model_stat[\"weights\"] /= model_stat[\"weights\"].sum().abs()\n", |
|
" model_stats.append(model_stat)\n", |
|
" \n", |
|
" init = None\n", |
|
" if init_image is not None:\n", |
|
" init = Image.open(fetch(init_image)).convert('RGB')\n", |
|
" init = init.resize((args.side_x, args.side_y), Image.LANCZOS)\n", |
|
" init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)\n", |
|
" \n", |
|
" if args.perlin_init:\n", |
|
" if args.perlin_mode == 'color':\n", |
|
" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", |
|
" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n", |
|
" elif args.perlin_mode == 'gray':\n", |
|
" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n", |
|
" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", |
|
" else:\n", |
|
" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", |
|
" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", |
|
" # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)\n", |
|
" init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n", |
|
" del init2\n", |
|
" \n", |
|
" cur_t = None\n", |
|
" \n", |
|
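"      # cond_fn implements the CLIP-guidance gradient: estimate the denoised image\n",
"      # from the current noisy x, take cutouts, score them against the prompt\n",
"      # embeddings, add TV/range/saturation/init losses, and return the (clamped)\n",
"      # gradient that steers the next diffusion step.\n",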
" def cond_fn(x, t, y=None):\n", |
|
" with torch.enable_grad():\n", |
|
" x_is_NaN = False\n", |
|
" x = x.detach().requires_grad_()\n", |
|
" n = x.shape[0]\n", |
|
" if use_secondary_model is True:\n", |
|
" alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n", |
|
" sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n", |
|
" cosine_t = alpha_sigma_to_t(alpha, sigma)\n", |
|
" out = secondary_model(x, cosine_t[None].repeat([n])).pred\n", |
|
" fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n", |
|
" x_in = out * fac + x * (1 - fac)\n", |
|
" x_in_grad = torch.zeros_like(x_in)\n", |
|
" else:\n", |
|
" my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t\n", |
|
" out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})\n", |
|
" fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n", |
|
" x_in = out['pred_xstart'] * fac + x * (1 - fac)\n", |
|
" x_in_grad = torch.zeros_like(x_in)\n", |
|
" for model_stat in model_stats:\n", |
|
" for i in range(args.cutn_batches):\n", |
|
" t_int = int(t.item())+1 #errors on last step without +1, need to find source\n", |
|
" #when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'\n", |
|
" try:\n", |
|
" input_resolution=model_stat[\"clip_model\"].visual.input_resolution\n", |
|
" except:\n", |
|
" input_resolution=224\n", |
|
"\n", |
|
" cuts = MakeCutoutsDango(input_resolution,\n", |
|
" Overview= args.cut_overview[1000-t_int], \n", |
|
" InnerCrop = args.cut_innercut[1000-t_int], IC_Size_Pow=args.cut_ic_pow, IC_Grey_P = args.cut_icgray_p[1000-t_int]\n", |
|
" )\n", |
|
" clip_in = normalize(cuts(x_in.add(1).div(2)))\n", |
|
" image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n", |
|
" dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat[\"target_embeds\"].unsqueeze(0))\n", |
|
" dists = dists.view([args.cut_overview[1000-t_int]+args.cut_innercut[1000-t_int], n, -1])\n", |
|
" losses = dists.mul(model_stat[\"weights\"]).sum(2).mean(0)\n", |
|
" loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch\n", |
|
" x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches\n", |
|
" tv_losses = tv_loss(x_in)\n", |
|
" if use_secondary_model is True:\n", |
|
" range_losses = range_loss(out)\n", |
|
" else:\n", |
|
" range_losses = range_loss(out['pred_xstart'])\n", |
|
" sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()\n", |
|
" loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale\n", |
|
" if init is not None and args.init_scale:\n", |
|
" init_losses = lpips_model(x_in, init)\n", |
|
" loss = loss + init_losses.sum() * args.init_scale\n", |
|
" x_in_grad += torch.autograd.grad(loss, x_in)[0]\n", |
|
" if torch.isnan(x_in_grad).any()==False:\n", |
|
" grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]\n", |
|
" else:\n", |
|
" # print(\"NaN'd\")\n", |
|
" x_is_NaN = True\n", |
|
" grad = torch.zeros_like(x)\n", |
|
" if args.clamp_grad and x_is_NaN == False:\n", |
|
" magnitude = grad.square().mean().sqrt()\n", |
|
" return grad * magnitude.clamp(max=args.clamp_max) / magnitude #min=-0.02, min=-clamp_max, \n", |
|
" return grad\n", |
|
" \n", |
|
" if args.sampling_mode == 'ddim':\n", |
|
" sample_fn = diffusion.ddim_sample_loop_progressive\n", |
|
" else:\n", |
|
" sample_fn = diffusion.plms_sample_loop_progressive\n", |
|
"\n", |
|
"\n", |
|
" image_display = Output()\n", |
|
" for i in range(args.n_batches):\n", |
|
" if args.animation_mode == 'None':\n", |
|
" display.clear_output(wait=True)\n", |
|
" batchBar = tqdm(range(args.n_batches), desc =\"Batches\")\n", |
|
" batchBar.n = i\n", |
|
" batchBar.refresh()\n", |
|
" print('')\n", |
|
" display.display(image_display)\n", |
|
" gc.collect()\n", |
|
" torch.cuda.empty_cache()\n", |
|
" cur_t = diffusion.num_timesteps - skip_steps - 1\n", |
|
" total_steps = cur_t\n", |
|
"\n", |
|
" if perlin_init:\n", |
|
" init = regen_perlin()\n", |
|
"\n", |
|
" if args.sampling_mode == 'ddim':\n", |
|
" samples = sample_fn(\n", |
|
" model,\n", |
|
" (batch_size, 3, args.side_y, args.side_x),\n", |
|
" clip_denoised=clip_denoised,\n", |
|
" model_kwargs={},\n", |
|
" cond_fn=cond_fn,\n", |
|
" progress=True,\n", |
|
" skip_timesteps=skip_steps,\n", |
|
" init_image=init,\n", |
|
" randomize_class=randomize_class,\n", |
|
" eta=eta,\n", |
|
" )\n", |
|
" else:\n", |
|
" samples = sample_fn(\n", |
|
" model,\n", |
|
" (batch_size, 3, args.side_y, args.side_x),\n", |
|
" clip_denoised=clip_denoised,\n", |
|
" model_kwargs={},\n", |
|
" cond_fn=cond_fn,\n", |
|
" progress=True,\n", |
|
" skip_timesteps=skip_steps,\n", |
|
" init_image=init,\n", |
|
" randomize_class=randomize_class,\n", |
|
" order=2,\n", |
|
" )\n", |
|
" \n", |
|
" \n", |
|
" # with run_display:\n", |
|
" # display.clear_output(wait=True)\n", |
|
" imgToSharpen = None\n", |
|
" for j, sample in enumerate(samples): \n", |
|
" cur_t -= 1\n", |
|
" intermediateStep = False\n", |
|
" if args.steps_per_checkpoint is not None:\n", |
|
" if j % steps_per_checkpoint == 0 and j > 0:\n", |
|
" intermediateStep = True\n", |
|
" elif j in args.intermediate_saves:\n", |
|
" intermediateStep = True\n", |
|
" with image_display:\n", |
|
" if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True:\n", |
|
" for k, image in enumerate(sample['pred_xstart']):\n", |
|
" # tqdm.write(f'Batch {i}, step {j}, output {k}:')\n", |
|
" current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')\n", |
|
" percent = math.ceil(j/total_steps*100)\n", |
|
" if args.n_batches > 0:\n", |
|
" #if intermediates are saved to the subfolder, don't append a step or percentage to the name\n", |
|
" if cur_t == -1 and args.intermediates_in_subfolder is True:\n", |
|
" save_num = f'{frame_num:04}' if animation_mode != \"None\" else i\n", |
|
" filename = f'{args.batch_name}({args.batchNum})_{save_num}.png'\n", |
|
" else:\n", |
|
" #If we're working with percentages, append it\n", |
|
" if args.steps_per_checkpoint is not None:\n", |
|
" filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png'\n", |
|
" # Or else, iIf we're working with specific steps, append those\n", |
|
" else:\n", |
|
" filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png'\n", |
|
" image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))\n", |
|
" if j % args.display_rate == 0 or cur_t == -1:\n", |
|
" image.save('progress.png')\n", |
|
" display.clear_output(wait=True)\n", |
|
" display.display(display.Image('progress.png'))\n", |
|
" if args.steps_per_checkpoint is not None:\n", |
|
" if j % args.steps_per_checkpoint == 0 and j > 0:\n", |
|
" if args.intermediates_in_subfolder is True:\n", |
|
" image.save(f'{partialFolder}/{filename}')\n", |
|
" else:\n", |
|
" image.save(f'{batchFolder}/{filename}')\n", |
|
" else:\n", |
|
" if j in args.intermediate_saves:\n", |
|
" if args.intermediates_in_subfolder is True:\n", |
|
" image.save(f'{partialFolder}/{filename}')\n", |
|
" else:\n", |
|
" image.save(f'{batchFolder}/{filename}')\n", |
|
" if cur_t == -1:\n", |
|
" if frame_num == 0:\n", |
|
" save_settings()\n", |
|
" if args.animation_mode != \"None\":\n", |
|
" image.save('prevFrame.png')\n", |
|
" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n", |
|
" imgToSharpen = image\n", |
|
" if args.keep_unsharp is True:\n", |
|
" image.save(f'{unsharpenFolder}/{filename}')\n", |
|
" else:\n", |
|
" image.save(f'{batchFolder}/{filename}')\n", |
|
" # if frame_num != args.max_frames-1:\n", |
|
" # display.clear_output()\n", |
|
"\n", |
|
" with image_display: \n", |
|
" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n", |
|
" print('Starting Diffusion Sharpening...')\n", |
|
" do_superres(imgToSharpen, f'{batchFolder}/{filename}')\n", |
|
" display.clear_output()\n", |
|
" \n", |
|
" plt.plot(np.array(loss_values), 'r')\n", |
|
"\n", |
|
"def save_settings():\n", |
|
" setting_list = {\n", |
|
" 'text_prompts': text_prompts,\n", |
|
" 'image_prompts': image_prompts,\n", |
|
" 'clip_guidance_scale': clip_guidance_scale,\n", |
|
" 'tv_scale': tv_scale,\n", |
|
" 'range_scale': range_scale,\n", |
|
" 'sat_scale': sat_scale,\n", |
|
" # 'cutn': cutn,\n", |
|
" 'cutn_batches': cutn_batches,\n", |
|
" 'max_frames': max_frames,\n", |
|
" 'interp_spline': interp_spline,\n", |
|
" # 'rotation_per_frame': rotation_per_frame,\n", |
|
" 'init_image': init_image,\n", |
|
" 'init_scale': init_scale,\n", |
|
" 'skip_steps': skip_steps,\n", |
|
" # 'zoom_per_frame': zoom_per_frame,\n", |
|
" 'frames_scale': frames_scale,\n", |
|
" 'frames_skip_steps': frames_skip_steps,\n", |
|
" 'perlin_init': perlin_init,\n", |
|
" 'perlin_mode': perlin_mode,\n", |
|
" 'skip_augs': skip_augs,\n", |
|
" 'randomize_class': randomize_class,\n", |
|
" 'clip_denoised': clip_denoised,\n", |
|
" 'clamp_grad': clamp_grad,\n", |
|
" 'clamp_max': clamp_max,\n", |
|
" 'seed': seed,\n", |
|
" 'fuzzy_prompt': fuzzy_prompt,\n", |
|
" 'rand_mag': rand_mag,\n", |
|
" 'eta': eta,\n", |
|
" 'width': width_height[0],\n", |
|
" 'height': width_height[1],\n", |
|
" 'diffusion_model': diffusion_model,\n", |
|
" 'use_secondary_model': use_secondary_model,\n", |
|
" 'steps': steps,\n", |
|
" 'diffusion_steps': diffusion_steps,\n", |
|
" 'sampling_mode': sampling_mode,\n", |
|
" 'ViTB32': ViTB32,\n", |
|
" 'ViTB16': ViTB16,\n", |
|
" 'ViTL14': ViTL14,\n", |
|
" 'RN101': RN101,\n", |
|
" 'RN50': RN50,\n", |
|
" 'RN50x4': RN50x4,\n", |
|
" 'RN50x16': RN50x16,\n", |
|
" 'RN50x64': RN50x64,\n", |
|
" 'cut_overview': str(cut_overview),\n", |
|
" 'cut_innercut': str(cut_innercut),\n", |
|
" 'cut_ic_pow': cut_ic_pow,\n", |
|
" 'cut_icgray_p': str(cut_icgray_p),\n", |
|
" 'key_frames': key_frames,\n", |
|
" 'max_frames': max_frames,\n", |
|
" 'angle': angle,\n", |
|
" 'zoom': zoom,\n", |
|
" 'translation_x': translation_x,\n", |
|
" 'translation_y': translation_y,\n", |
|
" 'translation_z': translation_z,\n", |
|
" 'rotation_3d_x': rotation_3d_x,\n", |
|
" 'rotation_3d_y': rotation_3d_y,\n", |
|
" 'rotation_3d_z': rotation_3d_z,\n", |
|
" 'midas_depth_model': midas_depth_model,\n", |
|
" 'midas_weight': midas_weight,\n", |
|
" 'near_plane': near_plane,\n", |
|
" 'far_plane': far_plane,\n", |
|
" 'fov': fov,\n", |
|
" 'padding_mode': padding_mode,\n", |
|
" 'sampling_mode': sampling_mode,\n", |
|
" 'video_init_path':video_init_path,\n", |
|
" 'extract_nth_frame':extract_nth_frame,\n", |
|
" }\n", |
|
" # print('Settings:', setting_list)\n", |
|
" with open(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\", \"w+\") as f: #save settings\n", |
|
" json.dump(setting_list, f, ensure_ascii=False, indent=4)" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"#@title 1.6 Define the secondary diffusion model\n", |
|
"\n", |
|
"def append_dims(x, n):\n", |
|
" return x[(Ellipsis, *(None,) * (n - x.ndim))]\n", |
|
"\n", |
|
"\n", |
|
"def expand_to_planes(x, shape):\n", |
|
" return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])\n", |
|
"\n", |
|
"\n", |
|
"def alpha_sigma_to_t(alpha, sigma):\n", |
|
" return torch.atan2(sigma, alpha) * 2 / math.pi\n", |
|
"\n", |
|
"\n", |
|
"def t_to_alpha_sigma(t):\n", |
|
" return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n", |
|
"\n", |
|
"\n", |
|
"@dataclass\n", |
|
"class DiffusionOutput:\n", |
|
" v: torch.Tensor\n", |
|
" pred: torch.Tensor\n", |
|
" eps: torch.Tensor\n", |
|
"\n", |
|
"\n", |
|
"class ConvBlock(nn.Sequential):\n", |
|
" def __init__(self, c_in, c_out):\n", |
|
" super().__init__(\n", |
|
" nn.Conv2d(c_in, c_out, 3, padding=1),\n", |
|
" nn.ReLU(inplace=True),\n", |
|
" )\n", |
|
"\n", |
|
"\n", |
|
"class SkipBlock(nn.Module):\n", |
|
" def __init__(self, main, skip=None):\n", |
|
" super().__init__()\n", |
|
" self.main = nn.Sequential(*main)\n", |
|
" self.skip = skip if skip else nn.Identity()\n", |
|
"\n", |
|
" def forward(self, input):\n", |
|
" return torch.cat([self.main(input), self.skip(input)], dim=1)\n", |
|
"\n", |
|
"\n", |
|
"class FourierFeatures(nn.Module):\n", |
|
" def __init__(self, in_features, out_features, std=1.):\n", |
|
" super().__init__()\n", |
|
" assert out_features % 2 == 0\n", |
|
" self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)\n", |
|
"\n", |
|
" def forward(self, input):\n", |
|
" f = 2 * math.pi * input @ self.weight.T\n", |
|
" return torch.cat([f.cos(), f.sin()], dim=-1)\n", |
|
"\n", |
|
"\n", |
|
"class SecondaryDiffusionImageNet(nn.Module):\n", |
|
" def __init__(self):\n", |
|
" super().__init__()\n", |
|
" c = 64 # The base channel count\n", |
|
"\n", |
|
" self.timestep_embed = FourierFeatures(1, 16)\n", |
|
"\n", |
|
" self.net = nn.Sequential(\n", |
|
" ConvBlock(3 + 16, c),\n", |
|
" ConvBlock(c, c),\n", |
|
" SkipBlock([\n", |
|
" nn.AvgPool2d(2),\n", |
|
" ConvBlock(c, c * 2),\n", |
|
" ConvBlock(c * 2, c * 2),\n", |
|
" SkipBlock([\n", |
|
" nn.AvgPool2d(2),\n", |
|
" ConvBlock(c * 2, c * 4),\n", |
|
" ConvBlock(c * 4, c * 4),\n", |
|
" SkipBlock([\n", |
|
" nn.AvgPool2d(2),\n", |
|
" ConvBlock(c * 4, c * 8),\n", |
|
" ConvBlock(c * 8, c * 4),\n", |
|
" nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n", |
|
" ]),\n", |
|
" ConvBlock(c * 8, c * 4),\n", |
|
" ConvBlock(c * 4, c * 2),\n", |
|
" nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n", |
|
" ]),\n", |
|
" ConvBlock(c * 4, c * 2),\n", |
|
" ConvBlock(c * 2, c),\n", |
|
" nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n", |
|
" ]),\n", |
|
" ConvBlock(c * 2, c),\n", |
|
" nn.Conv2d(c, 3, 3, padding=1),\n", |
|
" )\n", |
|
"\n", |
|
" def forward(self, input, t):\n", |
|
" timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n", |
|
" v = self.net(torch.cat([input, timestep_embed], dim=1))\n", |
|
" alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n", |
|
" pred = input * alphas - v * sigmas\n", |
|
" eps = input * sigmas + v * alphas\n", |
|
" return DiffusionOutput(v, pred, eps)\n", |
|
"\n", |
|
"\n", |
|
"class SecondaryDiffusionImageNet2(nn.Module):\n", |
|
" def __init__(self):\n", |
|
" super().__init__()\n", |
|
" c = 64 # The base channel count\n", |
|
" cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]\n", |
|
"\n", |
|
" self.timestep_embed = FourierFeatures(1, 16)\n", |
|
" self.down = nn.AvgPool2d(2)\n", |
|
" self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n", |
|
"\n", |
|
" self.net = nn.Sequential(\n", |
|
" ConvBlock(3 + 16, cs[0]),\n", |
|
" ConvBlock(cs[0], cs[0]),\n", |
|
" SkipBlock([\n", |
|
" self.down,\n", |
|
" ConvBlock(cs[0], cs[1]),\n", |
|
" ConvBlock(cs[1], cs[1]),\n", |
|
" SkipBlock([\n", |
|
" self.down,\n", |
|
" ConvBlock(cs[1], cs[2]),\n", |
|
" ConvBlock(cs[2], cs[2]),\n", |
|
" SkipBlock([\n", |
|
" self.down,\n", |
|
" ConvBlock(cs[2], cs[3]),\n", |
|
" ConvBlock(cs[3], cs[3]),\n", |
|
" SkipBlock([\n", |
|
" self.down,\n", |
|
" ConvBlock(cs[3], cs[4]),\n", |
|
" ConvBlock(cs[4], cs[4]),\n", |
|
" SkipBlock([\n", |
|
" self.down,\n", |
|
" ConvBlock(cs[4], cs[5]),\n", |
|
" ConvBlock(cs[5], cs[5]),\n", |
|
" ConvBlock(cs[5], cs[5]),\n", |
|
" ConvBlock(cs[5], cs[4]),\n", |
|
" self.up,\n", |
|
" ]),\n", |
|
" ConvBlock(cs[4] * 2, cs[4]),\n", |
|
" ConvBlock(cs[4], cs[3]),\n", |
|
" self.up,\n", |
|
" ]),\n", |
|
" ConvBlock(cs[3] * 2, cs[3]),\n", |
|
" ConvBlock(cs[3], cs[2]),\n", |
|
" self.up,\n", |
|
" ]),\n", |
|
" ConvBlock(cs[2] * 2, cs[2]),\n", |
|
" ConvBlock(cs[2], cs[1]),\n", |
|
" self.up,\n", |
|
" ]),\n", |
|
" ConvBlock(cs[1] * 2, cs[1]),\n", |
|
" ConvBlock(cs[1], cs[0]),\n", |
|
" self.up,\n", |
|
" ]),\n", |
|
" ConvBlock(cs[0] * 2, cs[0]),\n", |
|
" nn.Conv2d(cs[0], 3, 3, padding=1),\n", |
|
" )\n", |
|
"\n", |
|
" def forward(self, input, t):\n", |
|
" timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n", |
|
" v = self.net(torch.cat([input, timestep_embed], dim=1))\n", |
|
" alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n", |
|
" pred = input * alphas - v * sigmas\n", |
|
" eps = input * sigmas + v * alphas\n", |
|
" return DiffusionOutput(v, pred, eps)" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"#@title 1.7 SuperRes Define\n", |
|
"class DDIMSampler(object):\n", |
|
" def __init__(self, model, schedule=\"linear\", **kwargs):\n", |
|
" super().__init__()\n", |
|
" self.model = model\n", |
|
" self.ddpm_num_timesteps = model.num_timesteps\n", |
|
" self.schedule = schedule\n", |
|
"\n", |
|
" def register_buffer(self, name, attr):\n", |
|
" if type(attr) == torch.Tensor:\n", |
|
" if attr.device != torch.device(\"cuda\"):\n", |
|
" attr = attr.to(torch.device(\"cuda\"))\n", |
|
" setattr(self, name, attr)\n", |
|
"\n", |
|
" def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n", |
|
" self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n", |
|
" num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n", |
|
" alphas_cumprod = self.model.alphas_cumprod\n", |
|
" assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n", |
|
" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n", |
|
"\n", |
|
" self.register_buffer('betas', to_torch(self.model.betas))\n", |
|
" self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n", |
|
" self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n", |
|
"\n", |
|
" # calculations for diffusion q(x_t | x_{t-1}) and others\n", |
|
" self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n", |
|
" self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n", |
|
" self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n", |
|
" self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n", |
|
" self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n", |
|
"\n", |
|
" # ddim sampling parameters\n", |
|
" ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n", |
|
" ddim_timesteps=self.ddim_timesteps,\n", |
|
" eta=ddim_eta,verbose=verbose)\n", |
|
" self.register_buffer('ddim_sigmas', ddim_sigmas)\n", |
|
" self.register_buffer('ddim_alphas', ddim_alphas)\n", |
|
" self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n", |
|
" self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n", |
|
" sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n", |
|
" (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n", |
|
" 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n", |
|
" self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n", |
|
"\n", |
|
" @torch.no_grad()\n", |
|
" def sample(self,\n", |
|
" S,\n", |
|
" batch_size,\n", |
|
" shape,\n", |
|
" conditioning=None,\n", |
|
" callback=None,\n", |
|
" normals_sequence=None,\n", |
|
" img_callback=None,\n", |
|
" quantize_x0=False,\n", |
|
" eta=0.,\n", |
|
" mask=None,\n", |
|
" x0=None,\n", |
|
" temperature=1.,\n", |
|
" noise_dropout=0.,\n", |
|
" score_corrector=None,\n", |
|
" corrector_kwargs=None,\n", |
|
" verbose=True,\n", |
|
" x_T=None,\n", |
|
" log_every_t=100,\n", |
|
" **kwargs\n", |
|
" ):\n", |
|
" if conditioning is not None:\n", |
|
" if isinstance(conditioning, dict):\n", |
|
" cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n", |
|
" if cbs != batch_size:\n", |
|
" print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n", |
|
" else:\n", |
|
" if conditioning.shape[0] != batch_size:\n", |
|
" print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n", |
|
"\n", |
|
" self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n", |
|
" # sampling\n", |
|
" C, H, W = shape\n", |
|
" size = (batch_size, C, H, W)\n", |
|
" # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n", |
|
"\n", |
|
" samples, intermediates = self.ddim_sampling(conditioning, size,\n", |
|
" callback=callback,\n", |
|
" img_callback=img_callback,\n", |
|
" quantize_denoised=quantize_x0,\n", |
|
" mask=mask, x0=x0,\n", |
|
" ddim_use_original_steps=False,\n", |
|
" noise_dropout=noise_dropout,\n", |
|
" temperature=temperature,\n", |
|
" score_corrector=score_corrector,\n", |
|
" corrector_kwargs=corrector_kwargs,\n", |
|
" x_T=x_T,\n", |
|
" log_every_t=log_every_t\n", |
|
" )\n", |
|
" return samples, intermediates\n", |
|
"\n", |
|
" @torch.no_grad()\n", |
|
" def ddim_sampling(self, cond, shape,\n", |
|
" x_T=None, ddim_use_original_steps=False,\n", |
|
" callback=None, timesteps=None, quantize_denoised=False,\n", |
|
" mask=None, x0=None, img_callback=None, log_every_t=100,\n", |
|
" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n", |
|
" device = self.model.betas.device\n", |
|
" b = shape[0]\n", |
|
" if x_T is None:\n", |
|
" img = torch.randn(shape, device=device)\n", |
|
" else:\n", |
|
" img = x_T\n", |
|
"\n", |
|
" if timesteps is None:\n", |
|
" timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n", |
|
" elif timesteps is not None and not ddim_use_original_steps:\n", |
|
" subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n", |
|
" timesteps = self.ddim_timesteps[:subset_end]\n", |
|
"\n", |
|
" intermediates = {'x_inter': [img], 'pred_x0': [img]}\n", |
|
" time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n", |
|
" total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n", |
|
" print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n", |
|
"\n", |
|
" iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n", |
|
"\n", |
|
" for i, step in enumerate(iterator):\n", |
|
" index = total_steps - i - 1\n", |
|
" ts = torch.full((b,), step, device=device, dtype=torch.long)\n", |
|
"\n", |
|
" if mask is not None:\n", |
|
" assert x0 is not None\n", |
|
" img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n", |
|
" img = img_orig * mask + (1. - mask) * img\n", |
|
"\n", |
|
" outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n", |
|
" quantize_denoised=quantize_denoised, temperature=temperature,\n", |
|
" noise_dropout=noise_dropout, score_corrector=score_corrector,\n", |
|
" corrector_kwargs=corrector_kwargs)\n", |
|
" img, pred_x0 = outs\n", |
|
" if callback: callback(i)\n", |
|
" if img_callback: img_callback(pred_x0, i)\n", |
|
"\n", |
|
" if index % log_every_t == 0 or index == total_steps - 1:\n", |
|
" intermediates['x_inter'].append(img)\n", |
|
" intermediates['pred_x0'].append(pred_x0)\n", |
|
"\n", |
|
" return img, intermediates\n", |
|
"\n", |
|
" @torch.no_grad()\n", |
|
" def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n", |
|
" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n", |
|
" b, *_, device = *x.shape, x.device\n", |
|
" e_t = self.model.apply_model(x, t, c)\n", |
|
" if score_corrector is not None:\n", |
|
" assert self.model.parameterization == \"eps\"\n", |
|
" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n", |
|
"\n", |
|
" alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n", |
|
" alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n", |
|
" sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n", |
|
" sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n", |
|
" # select parameters corresponding to the currently considered timestep\n", |
|
" a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n", |
|
" a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n", |
|
" sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n", |
|
" sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n", |
|
"\n", |
|
" # current prediction for x_0\n", |
|
" pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n", |
|
" if quantize_denoised:\n", |
|
" pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n", |
|
" # direction pointing to x_t\n", |
|
" dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n", |
|
" noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n", |
|
" if noise_dropout > 0.:\n", |
|
" noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n", |
|
" x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n", |
|
" return x_prev, pred_x0\n", |
|
"\n", |
|
"\n", |
|
"def download_models(mode):\n", |
|
"\n", |
|
" if mode == \"superresolution\":\n", |
|
" # this is the small bsr light model\n", |
|
" url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n", |
|
" url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n", |
|
"\n", |
|
" path_conf = f'{model_path}/superres/project.yaml'\n", |
|
" path_ckpt = f'{model_path}/superres/last.ckpt'\n", |
|
"\n", |
|
" download_url(url_conf, path_conf)\n", |
|
" download_url(url_ckpt, path_ckpt)\n", |
|
"\n", |
|
" path_conf = path_conf + '/?dl=1' # fix it\n", |
|
" path_ckpt = path_ckpt + '/?dl=1' # fix it\n", |
|
" return path_conf, path_ckpt\n", |
|
"\n", |
|
" else:\n", |
|
" raise NotImplementedError\n", |
|
"\n", |
|
"\n", |
|
"def load_model_from_config(config, ckpt):\n", |
|
" print(f\"Loading model from {ckpt}\")\n", |
|
" pl_sd = torch.load(ckpt, map_location=\"cpu\")\n", |
|
" global_step = pl_sd[\"global_step\"]\n", |
|
" sd = pl_sd[\"state_dict\"]\n", |
|
" model = instantiate_from_config(config.model)\n", |
|
" m, u = model.load_state_dict(sd, strict=False)\n", |
|
" model.cuda()\n", |
|
" model.eval()\n", |
|
" return {\"model\": model}, global_step\n", |
|
"\n", |
|
"\n", |
|
"def get_model(mode):\n", |
|
" path_conf, path_ckpt = download_models(mode)\n", |
|
" config = OmegaConf.load(path_conf)\n", |
|
" model, step = load_model_from_config(config, path_ckpt)\n", |
|
" return model\n", |
|
"\n", |
|
"\n", |
|
"def get_custom_cond(mode):\n", |
|
" dest = \"data/example_conditioning\"\n", |
|
"\n", |
|
" if mode == \"superresolution\":\n", |
|
" uploaded_img = files.upload()\n", |
|
" filename = next(iter(uploaded_img))\n", |
|
" name, filetype = filename.split(\".\") # todo assumes just one dot in name !\n", |
|
" os.rename(f\"{filename}\", f\"{dest}/{mode}/custom_{name}.{filetype}\")\n", |
|
"\n", |
|
" elif mode == \"text_conditional\":\n", |
|
" w = widgets.Text(value='A cake with cream!', disabled=True)\n", |
|
" display.display(w)\n", |
|
"\n", |
|
" with open(f\"{dest}/{mode}/custom_{w.value[:20]}.txt\", 'w') as f:\n", |
|
" f.write(w.value)\n", |
|
"\n", |
|
" elif mode == \"class_conditional\":\n", |
|
" w = widgets.IntSlider(min=0, max=1000)\n", |
|
" display.display(w)\n", |
|
" with open(f\"{dest}/{mode}/custom.txt\", 'w') as f:\n", |
|
" f.write(w.value)\n", |
|
"\n", |
|
" else:\n", |
|
" raise NotImplementedError(f\"cond not implemented for mode{mode}\")\n", |
|
"\n", |
|
"\n", |
|
"def get_cond_options(mode):\n", |
|
" path = \"data/example_conditioning\"\n", |
|
" path = os.path.join(path, mode)\n", |
|
" onlyfiles = [f for f in sorted(os.listdir(path))]\n", |
|
" return path, onlyfiles\n", |
|
"\n", |
|
"\n", |
|
"def select_cond_path(mode):\n", |
|
" path = \"data/example_conditioning\" # todo\n", |
|
" path = os.path.join(path, mode)\n", |
|
" onlyfiles = [f for f in sorted(os.listdir(path))]\n", |
|
"\n", |
|
" selected = widgets.RadioButtons(\n", |
|
" options=onlyfiles,\n", |
|
" description='Select conditioning:',\n", |
|
" disabled=False\n", |
|
" )\n", |
|
" display.display(selected)\n", |
|
" selected_path = os.path.join(path, selected.value)\n", |
|
" return selected_path\n", |
|
"\n", |
|
"\n", |
|
"def get_cond(mode, img):\n", |
|
" example = dict()\n", |
|
" if mode == \"superresolution\":\n", |
|
" up_f = 4\n", |
|
" # visualize_cond_img(selected_path)\n", |
|
"\n", |
|
" c = img\n", |
|
" c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n", |
|
" c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)\n", |
|
" c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n", |
|
" c = rearrange(c, '1 c h w -> 1 h w c')\n", |
|
" c = 2. * c - 1.\n", |
|
"\n", |
|
" c = c.to(torch.device(\"cuda\"))\n", |
|
" example[\"LR_image\"] = c\n", |
|
" example[\"image\"] = c_up\n", |
|
"\n", |
|
" return example\n", |
|
"\n", |
|
"\n", |
|
"def visualize_cond_img(path):\n", |
|
" display.display(ipyimg(filename=path))\n", |
|
"\n", |
|
"\n", |
|
"def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):\n", |
|
" # global stride\n", |
|
"\n", |
|
" example = get_cond(task, img)\n", |
|
"\n", |
|
" save_intermediate_vid = False\n", |
|
" n_runs = 1\n", |
|
" masked = False\n", |
|
" guider = None\n", |
|
" ckwargs = None\n", |
|
" mode = 'ddim'\n", |
|
" ddim_use_x0_pred = False\n", |
|
" temperature = 1.\n", |
|
" eta = eta\n", |
|
" make_progrow = True\n", |
|
" custom_shape = None\n", |
|
"\n", |
|
" height, width = example[\"image\"].shape[1:3]\n", |
|
" split_input = height >= 128 and width >= 128\n", |
|
"\n", |
|
" if split_input:\n", |
|
" ks = 128\n", |
|
" stride = 64\n", |
|
" vqf = 4 #\n", |
|
" model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n", |
|
" \"vqf\": vqf,\n", |
|
" \"patch_distributed_vq\": True,\n", |
|
" \"tie_braker\": False,\n", |
|
" \"clip_max_weight\": 0.5,\n", |
|
" \"clip_min_weight\": 0.01,\n", |
|
" \"clip_max_tie_weight\": 0.5,\n", |
|
" \"clip_min_tie_weight\": 0.01}\n", |
|
" else:\n", |
|
" if hasattr(model, \"split_input_params\"):\n", |
|
" delattr(model, \"split_input_params\")\n", |
|
"\n", |
|
" invert_mask = False\n", |
|
"\n", |
|
" x_T = None\n", |
|
" for n in range(n_runs):\n", |
|
" if custom_shape is not None:\n", |
|
" x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n", |
|
" x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])\n", |
|
"\n", |
|
" logs = make_convolutional_sample(example, model,\n", |
|
" mode=mode, custom_steps=custom_steps,\n", |
|
" eta=eta, swap_mode=False , masked=masked,\n", |
|
" invert_mask=invert_mask, quantize_x0=False,\n", |
|
" custom_schedule=None, decode_interval=10,\n", |
|
" resize_enabled=resize_enabled, custom_shape=custom_shape,\n", |
|
" temperature=temperature, noise_dropout=0.,\n", |
|
" corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,\n", |
|
" make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred\n", |
|
" )\n", |
|
" return logs\n", |
|
"\n", |
|
"\n", |
|
"@torch.no_grad()\n", |
|
"def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n", |
|
" mask=None, x0=None, quantize_x0=False, img_callback=None,\n", |
|
" temperature=1., noise_dropout=0., score_corrector=None,\n", |
|
" corrector_kwargs=None, x_T=None, log_every_t=None\n", |
|
" ):\n", |
|
"\n", |
|
" ddim = DDIMSampler(model)\n", |
|
" bs = shape[0] # dont know where this comes from but wayne\n", |
|
" shape = shape[1:] # cut batch dim\n", |
|
" # print(f\"Sampling with eta = {eta}; steps: {steps}\")\n", |
|
" samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n", |
|
" normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n", |
|
" mask=mask, x0=x0, temperature=temperature, verbose=False,\n", |
|
" score_corrector=score_corrector,\n", |
|
" corrector_kwargs=corrector_kwargs, x_T=x_T)\n", |
|
"\n", |
|
" return samples, intermediates\n", |
|
"\n", |
|
"\n", |
|
"@torch.no_grad()\n", |
|
"def make_convolutional_sample(batch, model, mode=\"vanilla\", custom_steps=None, eta=1.0, swap_mode=False, masked=False,\n", |
|
" invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,\n", |
|
" resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,\n", |
|
" corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):\n", |
|
" log = dict()\n", |
|
"\n", |
|
" z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n", |
|
" return_first_stage_outputs=True,\n", |
|
" force_c_encode=not (hasattr(model, 'split_input_params')\n", |
|
" and model.cond_stage_key == 'coordinates_bbox'),\n", |
|
" return_original_cond=True)\n", |
|
"\n", |
|
" log_every_t = 1 if save_intermediate_vid else None\n", |
|
"\n", |
|
" if custom_shape is not None:\n", |
|
" z = torch.randn(custom_shape)\n", |
|
" # print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n", |
|
"\n", |
|
" z0 = None\n", |
|
"\n", |
|
" log[\"input\"] = x\n", |
|
" log[\"reconstruction\"] = xrec\n", |
|
"\n", |
|
" if ismap(xc):\n", |
|
" log[\"original_conditioning\"] = model.to_rgb(xc)\n", |
|
" if hasattr(model, 'cond_stage_key'):\n", |
|
" log[model.cond_stage_key] = model.to_rgb(xc)\n", |
|
"\n", |
|
" else:\n", |
|
" log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n", |
|
" if model.cond_stage_model:\n", |
|
" log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n", |
|
" if model.cond_stage_key =='class_label':\n", |
|
" log[model.cond_stage_key] = xc[model.cond_stage_key]\n", |
|
"\n", |
|
" with model.ema_scope(\"Plotting\"):\n", |
|
" t0 = time.time()\n", |
|
" img_cb = None\n", |
|
"\n", |
|
" sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n", |
|
" eta=eta,\n", |
|
" quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,\n", |
|
" temperature=temperature, noise_dropout=noise_dropout,\n", |
|
" score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n", |
|
" x_T=x_T, log_every_t=log_every_t)\n", |
|
" t1 = time.time()\n", |
|
"\n", |
|
" if ddim_use_x0_pred:\n", |
|
" sample = intermediates['pred_x0'][-1]\n", |
|
"\n", |
|
" x_sample = model.decode_first_stage(sample)\n", |
|
"\n", |
|
" try:\n", |
|
" x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n", |
|
" log[\"sample_noquant\"] = x_sample_noquant\n", |
|
" log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n", |
|
" except:\n", |
|
" pass\n", |
|
"\n", |
|
" log[\"sample\"] = x_sample\n", |
|
" log[\"time\"] = t1 - t0\n", |
|
"\n", |
|
" return log\n", |
|
"\n", |
|
"sr_diffMode = 'superresolution'\n", |
|
"sr_model = get_model('superresolution')\n", |
|
"\n", |
|
"\n", |
|
"\n", |
|
"\n", |
|
"\n", |
|
"\n", |
|
"def do_superres(img, filepath):\n", |
|
"\n", |
|
" if args.sharpen_preset == 'Faster':\n", |
|
" sr_diffusion_steps = \"25\" \n", |
|
" sr_pre_downsample = '1/2' \n", |
|
" if args.sharpen_preset == 'Fast':\n", |
|
" sr_diffusion_steps = \"100\" \n", |
|
" sr_pre_downsample = '1/2' \n", |
|
" if args.sharpen_preset == 'Slow':\n", |
|
" sr_diffusion_steps = \"25\" \n", |
|
" sr_pre_downsample = 'None' \n", |
|
" if args.sharpen_preset == 'Very Slow':\n", |
|
" sr_diffusion_steps = \"100\" \n", |
|
" sr_pre_downsample = 'None' \n", |
|
"\n", |
|
"\n", |
|
" sr_post_downsample = 'Original Size'\n", |
|
" sr_diffusion_steps = int(sr_diffusion_steps)\n", |
|
" sr_eta = 1.0 \n", |
|
" sr_downsample_method = 'Lanczos' \n", |
|
"\n", |
|
" gc.collect()\n", |
|
" torch.cuda.empty_cache()\n", |
|
"\n", |
|
" im_og = img\n", |
|
" width_og, height_og = im_og.size\n", |
|
"\n", |
|
" #Downsample Pre\n", |
|
" if sr_pre_downsample == '1/2':\n", |
|
" downsample_rate = 2\n", |
|
" elif sr_pre_downsample == '1/4':\n", |
|
" downsample_rate = 4\n", |
|
" else:\n", |
|
" downsample_rate = 1\n", |
|
"\n", |
|
" width_downsampled_pre = width_og//downsample_rate\n", |
|
" height_downsampled_pre = height_og//downsample_rate\n", |
|
"\n", |
|
" if downsample_rate != 1:\n", |
|
" # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n", |
|
" im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n", |
|
" # im_og.save('/content/temp.png')\n", |
|
" # filepath = '/content/temp.png'\n", |
|
"\n", |
|
" logs = sr_run(sr_model[\"model\"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)\n", |
|
"\n", |
|
" sample = logs[\"sample\"]\n", |
|
" sample = sample.detach().cpu()\n", |
|
" sample = torch.clamp(sample, -1., 1.)\n", |
|
" sample = (sample + 1.) / 2. * 255\n", |
|
" sample = sample.numpy().astype(np.uint8)\n", |
|
" sample = np.transpose(sample, (0, 2, 3, 1))\n", |
|
" a = Image.fromarray(sample[0])\n", |
|
"\n", |
|
" #Downsample Post\n", |
|
" if sr_post_downsample == '1/2':\n", |
|
" downsample_rate = 2\n", |
|
" elif sr_post_downsample == '1/4':\n", |
|
" downsample_rate = 4\n", |
|
" else:\n", |
|
" downsample_rate = 1\n", |
|
"\n", |
|
" width, height = a.size\n", |
|
" width_downsampled_post = width//downsample_rate\n", |
|
" height_downsampled_post = height//downsample_rate\n", |
|
"\n", |
|
" if sr_downsample_method == 'Lanczos':\n", |
|
" aliasing = Image.LANCZOS\n", |
|
" else:\n", |
|
" aliasing = Image.NEAREST\n", |
|
"\n", |
|
" if downsample_rate != 1:\n", |
|
" # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')\n", |
|
" a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)\n", |
|
" elif sr_post_downsample == 'Original Size':\n", |
|
" # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')\n", |
|
" a = a.resize((width_og, height_og), aliasing)\n", |
|
"\n", |
|
" display.display(a)\n", |
|
" a.save(filepath)\n", |
|
" return\n", |
|
" print(f'Processing finished!')\n" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "markdown", |
|
"metadata": {}, |
|
"source": [ |
|
"# 2. Diffusion and CLIP model settings" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"#@markdown ####**Models Settings:**\n", |
|
"diffusion_model = \"512x512_diffusion_uncond_finetune_008100\" #@param [\"256x256_diffusion_uncond\", \"512x512_diffusion_uncond_finetune_008100\"]\n", |
|
"use_secondary_model = True #@param {type: 'boolean'}\n", |
|
"sampling_mode = 'ddim' #@param ['plms','ddim'] \n", |
|
"\n", |
|
"timestep_respacing = '250' # param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] \n", |
|
"diffusion_steps = 1000 # param {type: 'number'}\n", |
|
"use_checkpoint = True #@param {type: 'boolean'}\n", |
|
"ViTB32 = True #@param{type:\"boolean\"}\n", |
|
"ViTB16 = True #@param{type:\"boolean\"}\n", |
|
"ViTL14 = False #@param{type:\"boolean\"}\n", |
|
"RN101 = False #@param{type:\"boolean\"}\n", |
|
"RN50 = True #@param{type:\"boolean\"}\n", |
|
"RN50x4 = False #@param{type:\"boolean\"}\n", |
|
"RN50x16 = False #@param{type:\"boolean\"}\n", |
|
"RN50x64 = False #@param{type:\"boolean\"}\n", |
|
"SLIPB16 = False # param{type:\"boolean\"}\n", |
|
"SLIPL16 = False # param{type:\"boolean\"}\n", |
|
"\n", |
|
"#@markdown If you're having issues with model downloads, check this to compare SHA's:\n", |
|
"check_model_SHA = False #@param{type:\"boolean\"}\n", |
|
"\n", |
|
"model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", |
|
"model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n", |
|
"model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n", |
|
"\n", |
|
"model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n", |
|
"model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n", |
|
"model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n", |
|
"\n", |
|
"model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n", |
|
"model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n", |
|
"model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n", |
|
"\n", |
|
"# Download the diffusion model\n", |
|
"if diffusion_model == '256x256_diffusion_uncond':\n", |
|
" if os.path.exists(model_256_path) and check_model_SHA:\n", |
|
" print('Checking 256 Diffusion File')\n", |
|
" with open(model_256_path,\"rb\") as f:\n", |
|
" bytes = f.read() \n", |
|
" hash = hashlib.sha256(bytes).hexdigest();\n", |
|
" if hash == model_256_SHA:\n", |
|
" print('256 Model SHA matches')\n", |
|
" model_256_downloaded = True\n", |
|
" else: \n", |
|
" print(\"256 Model SHA doesn't match, redownloading...\")\n", |
|
" wget(model_256_link, model_path)\n", |
|
" model_256_downloaded = True\n", |
|
" elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n", |
|
" print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n", |
|
" else: \n", |
|
" wget(model_256_link, model_path)\n", |
|
" model_256_downloaded = True\n", |
|
"elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", |
|
" if os.path.exists(model_512_path) and check_model_SHA:\n", |
|
" print('Checking 512 Diffusion File')\n", |
|
" with open(model_512_path,\"rb\") as f:\n", |
|
" bytes = f.read() \n", |
|
" hash = hashlib.sha256(bytes).hexdigest();\n", |
|
" if hash == model_512_SHA:\n", |
|
" print('512 Model SHA matches')\n", |
|
" model_512_downloaded = True\n", |
|
" else: \n", |
|
" print(\"512 Model SHA doesn't match, redownloading...\")\n", |
|
" wget(model_512_link, model_path)\n", |
|
" model_512_downloaded = True\n", |
|
" elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n", |
|
" print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n", |
|
" else: \n", |
|
" wget(model_512_link, model_path)\n", |
|
" model_512_downloaded = True\n", |
|
"\n", |
|
"\n", |
|
"# Download the secondary diffusion model v2\n", |
|
"if use_secondary_model == True:\n", |
|
" if os.path.exists(model_secondary_path) and check_model_SHA:\n", |
|
" print('Checking Secondary Diffusion File')\n", |
|
" with open(model_secondary_path,\"rb\") as f:\n", |
|
" bytes = f.read() \n", |
|
" hash = hashlib.sha256(bytes).hexdigest();\n", |
|
" if hash == model_secondary_SHA:\n", |
|
" print('Secondary Model SHA matches')\n", |
|
" model_secondary_downloaded = True\n", |
|
" else: \n", |
|
" print(\"Secondary Model SHA doesn't match, redownloading...\")\n", |
|
" wget(model_secondary_link, model_path)\n", |
|
" model_secondary_downloaded = True\n", |
|
" elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n", |
|
" print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n", |
|
" else: \n", |
|
" wget(model_secondary_link, model_path)\n", |
|
" model_secondary_downloaded = True\n", |
|
"\n", |
|
"model_config = model_and_diffusion_defaults()\n", |
|
"if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", |
|
" model_config.update({\n", |
|
" 'attention_resolutions': '32, 16, 8',\n", |
|
" 'class_cond': False,\n", |
|
" 'diffusion_steps': diffusion_steps,\n", |
|
" 'rescale_timesteps': True,\n", |
|
" 'timestep_respacing': timestep_respacing,\n", |
|
" 'image_size': 512,\n", |
|
" 'learn_sigma': True,\n", |
|
" 'noise_schedule': 'linear',\n", |
|
" 'num_channels': 256,\n", |
|
" 'num_head_channels': 64,\n", |
|
" 'num_res_blocks': 2,\n", |
|
" 'resblock_updown': True,\n", |
|
" 'use_checkpoint': use_checkpoint,\n", |
|
" 'use_fp16': True,\n", |
|
" 'use_scale_shift_norm': True,\n", |
|
" })\n", |
|
"elif diffusion_model == '256x256_diffusion_uncond':\n", |
|
" model_config.update({\n", |
|
" 'attention_resolutions': '32, 16, 8',\n", |
|
" 'class_cond': False,\n", |
|
" 'diffusion_steps': diffusion_steps,\n", |
|
" 'rescale_timesteps': True,\n", |
|
" 'timestep_respacing': timestep_respacing,\n", |
|
" 'image_size': 256,\n", |
|
" 'learn_sigma': True,\n", |
|
" 'noise_schedule': 'linear',\n", |
|
" 'num_channels': 256,\n", |
|
" 'num_head_channels': 64,\n", |
|
" 'num_res_blocks': 2,\n", |
|
" 'resblock_updown': True,\n", |
|
" 'use_checkpoint': use_checkpoint,\n", |
|
" 'use_fp16': True,\n", |
|
" 'use_scale_shift_norm': True,\n", |
|
" })\n", |
|
"\n", |
|
"secondary_model_ver = 2\n", |
|
"model_default = model_config['image_size']\n", |
|
"\n", |
|
"\n", |
|
"\n", |
|
"if secondary_model_ver == 2:\n", |
|
" secondary_model = SecondaryDiffusionImageNet2()\n", |
|
" secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))\n", |
|
"secondary_model.eval().requires_grad_(False).to(device)\n", |
|
"\n", |
|
"clip_models = []\n", |
|
"if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) \n", |
|
"if ViTB16 is True: clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device) ) \n", |
|
"if ViTL14 is True: clip_models.append(clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device) ) \n", |
|
"if RN50 is True: clip_models.append(clip.load('RN50', jit=False)[0].eval().requires_grad_(False).to(device))\n", |
|
"if RN50x4 is True: clip_models.append(clip.load('RN50x4', jit=False)[0].eval().requires_grad_(False).to(device)) \n", |
|
"if RN50x16 is True: clip_models.append(clip.load('RN50x16', jit=False)[0].eval().requires_grad_(False).to(device)) \n", |
|
"if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device)) \n", |
|
"if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device)) \n", |
|
"\n", |
|
"if SLIPB16:\n", |
|
" SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n", |
|
" if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):\n", |
|
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt\", model_path)\n", |
|
" sd = torch.load(f'{model_path}/slip_base_100ep.pt')\n", |
|
" real_sd = {}\n", |
|
" for k, v in sd['state_dict'].items():\n", |
|
" real_sd['.'.join(k.split('.')[1:])] = v\n", |
|
" del sd\n", |
|
" SLIPB16model.load_state_dict(real_sd)\n", |
|
" SLIPB16model.requires_grad_(False).eval().to(device)\n", |
|
"\n", |
|
" clip_models.append(SLIPB16model)\n", |
|
"\n", |
|
"if SLIPL16:\n", |
|
" SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n", |
|
" if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n", |
|
" wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", model_path)\n", |
|
" sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n", |
|
" real_sd = {}\n", |
|
" for k, v in sd['state_dict'].items():\n", |
|
" real_sd['.'.join(k.split('.')[1:])] = v\n", |
|
" del sd\n", |
|
" SLIPL16model.load_state_dict(real_sd)\n", |
|
" SLIPL16model.requires_grad_(False).eval().to(device)\n", |
|
"\n", |
|
" clip_models.append(SLIPL16model)\n", |
|
"\n", |
|
"normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n", |
|
"lpips_model = lpips.LPIPS(net='vgg').to(device)\n" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "markdown", |
|
"metadata": {}, |
|
"source": [ |
|
"# 3. Settings" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"#@markdown ####**Basic Settings:**\n", |
|
"batch_name = 'TimeToDisco' #@param{type: 'string'}\n", |
|
"steps = 250 #@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}\n", |
|
"width_height = [1280, 768]#@param{type: 'raw'}\n", |
|
"clip_guidance_scale = 5000 #@param{type: 'number'}\n", |
|
"tv_scale = 0#@param{type: 'number'}\n", |
|
"range_scale = 150#@param{type: 'number'}\n", |
|
"sat_scale = 0#@param{type: 'number'}\n", |
|
"cutn_batches = 4 #@param{type: 'number'}\n", |
|
"skip_augs = False#@param{type: 'boolean'}\n", |
|
"\n", |
|
"#@markdown ---\n", |
|
"\n", |
|
"#@markdown ####**Init Settings:**\n", |
|
"init_image = None #@param{type: 'string'}\n", |
|
"init_scale = 1000 #@param{type: 'integer'}\n", |
|
"skip_steps = 10 #@param{type: 'integer'}\n", |
|
"#@markdown *Make sure you set skip_steps to ~50% of your steps if you want to use an init image.*\n", |
|
"\n", |
|
"#Get corrected sizes\n", |
|
"side_x = (width_height[0]//64)*64;\n", |
|
"side_y = (width_height[1]//64)*64;\n", |
|
"if side_x != width_height[0] or side_y != width_height[1]:\n", |
|
" print(f'Changing output size to {side_x}x{side_y}. Dimensions must by multiples of 64.')\n", |
|
"\n", |
|
"#Update Model Settings\n", |
|
"timestep_respacing = f'ddim{steps}'\n", |
|
"diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n", |
|
"model_config.update({\n", |
|
" 'timestep_respacing': timestep_respacing,\n", |
|
" 'diffusion_steps': diffusion_steps,\n", |
|
"})\n", |
|
"\n", |
|
"#Make folder for batch\n", |
|
"batchFolder = f'{outDirPath}/{batch_name}'\n", |
|
"createPath(batchFolder)\n" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "markdown", |
|
"metadata": { |
|
"id": "CnkTNXJAPzL2" |
|
}, |
|
"source": [ |
|
"### Animation Settings" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"#@markdown ####**Animation Mode:**\n", |
|
"animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'}\n", |
|
"#@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.*\n", |
|
"\n", |
|
"\n", |
|
"#@markdown ---\n", |
|
"\n", |
|
"#@markdown ####**Video Input Settings:**\n", |
|
"if is_colab:\n", |
|
" video_init_path = \"/content/training.mp4\" #@param {type: 'string'}\n", |
|
"else:\n", |
|
" video_init_path = \"training.mp4\" #@param {type: 'string'}\n", |
|
"extract_nth_frame = 2 #@param {type:\"number\"} \n", |
|
"\n", |
|
"if animation_mode == \"Video Input\":\n", |
|
" if is_colab:\n", |
|
" videoFramesFolder = f'/content/videoFrames'\n", |
|
" else:\n", |
|
" videoFramesFolder = f'videoFrames'\n", |
|
" createPath(videoFramesFolder)\n", |
|
" print(f\"Exporting Video Frames (1 every {extract_nth_frame})...\")\n", |
|
" try:\n", |
|
" for f in pathlib.Path(f'{videoFramesFolder}').glob('*.jpg'):\n", |
|
" f.unlink()\n", |
|
" except:\n", |
|
" print('')\n", |
|
" vf = f'\"select=not(mod(n\\,{extract_nth_frame}))\"'\n", |
|
" subprocess.run(['ffmpeg', '-i', f'{video_init_path}', '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', '-loglevel', 'error', '-stats', f'{videoFramesFolder}/%04d.jpg'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", |
|
" #!ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg\n", |
|
"\n", |
|
"\n", |
|
"#@markdown ---\n", |
|
"\n", |
|
"#@markdown ####**2D Animation Settings:**\n", |
|
"#@markdown `zoom` is a multiplier of dimensions, 1 is no zoom.\n", |
|
"\n", |
|
"key_frames = True #@param {type:\"boolean\"}\n", |
|
"max_frames = 10000#@param {type:\"number\"}\n", |
|
"\n", |
|
"if animation_mode == \"Video Input\":\n", |
|
" max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))\n", |
|
"\n", |
|
"interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:\"string\"}\n", |
|
"angle = \"0:(0)\"#@param {type:\"string\"}\n", |
|
"zoom = \"0: (1), 10: (1.05)\"#@param {type:\"string\"}\n", |
|
"translation_x = \"0: (0)\"#@param {type:\"string\"}\n", |
|
"translation_y = \"0: (0)\"#@param {type:\"string\"}\n", |
|
"translation_z = \"0: (10.0)\"#@param {type:\"string\"}\n", |
|
"rotation_3d_x = \"0: (0)\"#@param {type:\"string\"}\n", |
|
"rotation_3d_y = \"0: (0)\"#@param {type:\"string\"}\n", |
|
"rotation_3d_z = \"0: (0)\"#@param {type:\"string\"}\n", |
|
"midas_depth_model = \"dpt_large\"#@param {type:\"string\"}\n", |
|
"midas_weight = 0.3#@param {type:\"number\"}\n", |
|
"near_plane = 200#@param {type:\"number\"}\n", |
|
"far_plane = 10000#@param {type:\"number\"}\n", |
|
"fov = 40#@param {type:\"number\"}\n", |
|
"padding_mode = 'border'#@param {type:\"string\"}\n", |
|
"sampling_mode = 'bicubic'#@param {type:\"string\"}\n", |
|
"\n", |
|
"#@markdown ---\n", |
|
"\n", |
|
"#@markdown ####**Coherency Settings:**\n", |
|
"#@markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500.\n", |
|
"frames_scale = 1500 #@param{type: 'integer'}\n", |
|
"#@markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.\n", |
|
"frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}\n", |
|
"\n", |
|
"\n", |
|
"def parse_key_frames(string, prompt_parser=None):\n", |
|
" \"\"\"Given a string representing frame numbers paired with parameter values at that frame,\n", |
|
" return a dictionary with the frame numbers as keys and the parameter values as the values.\n", |
|
"\n", |
|
" Parameters\n", |
|
" ----------\n", |
|
" string: string\n", |
|
" Frame numbers paired with parameter values at that frame number, in the format\n", |
|
" 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'\n", |
|
" prompt_parser: function or None, optional\n", |
|
" If provided, prompt_parser will be applied to each string of parameter values.\n", |
|
" \n", |
|
" Returns\n", |
|
" -------\n", |
|
" dict\n", |
|
" Frame numbers as keys, parameter values at that frame number as values\n", |
|
"\n", |
|
" Raises\n", |
|
" ------\n", |
|
" RuntimeError\n", |
|
" If the input string does not match the expected format.\n", |
|
" \n", |
|
" Examples\n", |
|
" --------\n", |
|
" >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\")\n", |
|
" {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}\n", |
|
"\n", |
|
" >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\", prompt_parser=lambda x: x.lower()))\n", |
|
" {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}\n", |
|
" \"\"\"\n", |
|
" import re\n", |
|
" pattern = r'((?P<frame>[0-9]+):[\\s]*[\\(](?P<param>[\\S\\s]*?)[\\)])'\n", |
|
" frames = dict()\n", |
|
" for match_object in re.finditer(pattern, string):\n", |
|
" frame = int(match_object.groupdict()['frame'])\n", |
|
" param = match_object.groupdict()['param']\n", |
|
" if prompt_parser:\n", |
|
" frames[frame] = prompt_parser(param)\n", |
|
" else:\n", |
|
" frames[frame] = param\n", |
|
"\n", |
|
" if frames == {} and len(string) != 0:\n", |
|
" raise RuntimeError('Key Frame string not correctly formatted')\n", |
|
" return frames\n", |
|
"\n", |
|
"def get_inbetweens(key_frames, integer=False):\n", |
|
" \"\"\"Given a dict with frame numbers as keys and a parameter value as values,\n", |
|
" return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.\n", |
|
" Any values not provided in the input dict are calculated by linear interpolation between\n", |
|
" the values of the previous and next provided frames. If there is no previous provided frame, then\n", |
|
" the value is equal to the value of the next provided frame, or if there is no next provided frame,\n", |
|
" then the value is equal to the value of the previous provided frame. If no frames are provided,\n", |
|
" all frame values are NaN.\n", |
|
"\n", |
|
" Parameters\n", |
|
" ----------\n", |
|
" key_frames: dict\n", |
|
" A dict with integer frame numbers as keys and numerical values of a particular parameter as values.\n", |
|
" integer: Bool, optional\n", |
|
" If True, the values of the output series are converted to integers.\n", |
|
" Otherwise, the values are floats.\n", |
|
" \n", |
|
" Returns\n", |
|
" -------\n", |
|
" pd.Series\n", |
|
" A Series with length max_frames representing the parameter values for each frame.\n", |
|
" \n", |
|
" Examples\n", |
|
" --------\n", |
|
" >>> max_frames = 5\n", |
|
" >>> get_inbetweens({1: 5, 3: 6})\n", |
|
" 0 5.0\n", |
|
" 1 5.0\n", |
|
" 2 5.5\n", |
|
" 3 6.0\n", |
|
" 4 6.0\n", |
|
" dtype: float64\n", |
|
"\n", |
|
" >>> get_inbetweens({1: 5, 3: 6}, integer=True)\n", |
|
" 0 5\n", |
|
" 1 5\n", |
|
" 2 5\n", |
|
" 3 6\n", |
|
" 4 6\n", |
|
" dtype: int64\n", |
|
" \"\"\"\n", |
|
" key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n", |
|
"\n", |
|
" for i, value in key_frames.items():\n", |
|
" key_frame_series[i] = value\n", |
|
" key_frame_series = key_frame_series.astype(float)\n", |
|
" \n", |
|
" interp_method = interp_spline\n", |
|
"\n", |
|
" if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n", |
|
" interp_method = 'Quadratic'\n", |
|
" \n", |
|
" if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n", |
|
" interp_method = 'Linear'\n", |
|
" \n", |
|
" \n", |
|
" key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n", |
|
" key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n", |
|
" # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')\n", |
|
" key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n", |
|
" if integer:\n", |
|
" return key_frame_series.astype(int)\n", |
|
" return key_frame_series\n", |
|
"\n", |
|
"def split_prompts(prompts):\n", |
|
" prompt_series = pd.Series([np.nan for a in range(max_frames)])\n", |
|
" for i, prompt in prompts.items():\n", |
|
" prompt_series[i] = prompt\n", |
|
" # prompt_series = prompt_series.astype(str)\n", |
|
" prompt_series = prompt_series.ffill().bfill()\n", |
|
" return prompt_series\n", |
|
"\n", |
|
"if key_frames:\n", |
|
" try:\n", |
|
" angle_series = get_inbetweens(parse_key_frames(angle))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `angle` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `angle` as \"\n", |
|
" f'\"0: ({angle})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" angle = f\"0: ({angle})\"\n", |
|
" angle_series = get_inbetweens(parse_key_frames(angle))\n", |
|
"\n", |
|
" try:\n", |
|
" zoom_series = get_inbetweens(parse_key_frames(zoom))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `zoom` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `zoom` as \"\n", |
|
" f'\"0: ({zoom})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" zoom = f\"0: ({zoom})\"\n", |
|
" zoom_series = get_inbetweens(parse_key_frames(zoom))\n", |
|
"\n", |
|
" try:\n", |
|
" translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `translation_x` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `translation_x` as \"\n", |
|
" f'\"0: ({translation_x})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" translation_x = f\"0: ({translation_x})\"\n", |
|
" translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n", |
|
"\n", |
|
" try:\n", |
|
" translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `translation_y` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `translation_y` as \"\n", |
|
" f'\"0: ({translation_y})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" translation_y = f\"0: ({translation_y})\"\n", |
|
" translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n", |
|
"\n", |
|
" try:\n", |
|
" translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `translation_z` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `translation_z` as \"\n", |
|
" f'\"0: ({translation_z})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" translation_z = f\"0: ({translation_z})\"\n", |
|
" translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n", |
|
"\n", |
|
" try:\n", |
|
" rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `rotation_3d_x` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `rotation_3d_x` as \"\n", |
|
" f'\"0: ({rotation_3d_x})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" rotation_3d_x = f\"0: ({rotation_3d_x})\"\n", |
|
" rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n", |
|
"\n", |
|
" try:\n", |
|
" rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `rotation_3d_y` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `rotation_3d_y` as \"\n", |
|
" f'\"0: ({rotation_3d_y})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" rotation_3d_y = f\"0: ({rotation_3d_y})\"\n", |
|
" rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n", |
|
"\n", |
|
" try:\n", |
|
" rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n", |
|
" except RuntimeError as e:\n", |
|
" print(\n", |
|
" \"WARNING: You have selected to use key frames, but you have not \"\n", |
|
" \"formatted `rotation_3d_z` correctly for key frames.\\n\"\n", |
|
" \"Attempting to interpret `rotation_3d_z` as \"\n", |
|
" f'\"0: ({rotation_3d_z})\"\\n'\n", |
|
" \"Please read the instructions to find out how to use key frames \"\n", |
|
" \"correctly.\\n\"\n", |
|
" )\n", |
|
" rotation_3d_z = f\"0: ({rotation_3d_z})\"\n", |
|
" rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n", |
|
"\n", |
|
"else:\n", |
|
" angle = float(angle)\n", |
|
" zoom = float(zoom)\n", |
|
" translation_x = float(translation_x)\n", |
|
" translation_y = float(translation_y)\n", |
|
" translation_z = float(translation_z)\n", |
|
" rotation_3d_x = float(rotation_3d_x)\n", |
|
" rotation_3d_y = float(rotation_3d_y)\n", |
|
" rotation_3d_z = float(rotation_3d_z)\n" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "markdown", |
|
"metadata": { |
|
"id": "u1VHzHvNx5fd" |
|
}, |
|
"source": [ |
|
"### Extra Settings\n", |
|
" Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"#@markdown ####**Saving:**\n", |
|
"\n", |
|
"intermediate_saves = 0#@param{type: 'raw'}\n", |
|
"intermediates_in_subfolder = True #@param{type: 'boolean'}\n", |
|
"#@markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps \n", |
|
"\n", |
|
"#@markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.\n", |
|
"\n", |
|
"#@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets)\n", |
|
"\n", |
|
"\n", |
|
"if type(intermediate_saves) is not list:\n", |
|
" if intermediate_saves:\n", |
|
" steps_per_checkpoint = math.floor((steps - skip_steps - 1) // (intermediate_saves+1))\n", |
|
" steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1\n", |
|
" print(f'Will save every {steps_per_checkpoint} steps')\n", |
|
" else:\n", |
|
" steps_per_checkpoint = steps+10\n", |
|
"else:\n", |
|
" steps_per_checkpoint = None\n", |
|
"\n", |
|
"if intermediate_saves and intermediates_in_subfolder is True:\n", |
|
" partialFolder = f'{batchFolder}/partials'\n", |
|
" createPath(partialFolder)\n", |
|
"\n", |
|
" #@markdown ---\n", |
|
"\n", |
|
"#@markdown ####**SuperRes Sharpening:**\n", |
|
"#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*\n", |
|
"sharpen_preset = 'Off' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']\n", |
|
"keep_unsharp = True #@param{type: 'boolean'}\n", |
|
"\n", |
|
"if sharpen_preset != 'Off' and keep_unsharp is True:\n", |
|
" unsharpenFolder = f'{batchFolder}/unsharpened'\n", |
|
" createPath(unsharpenFolder)\n", |
|
"\n", |
|
"\n", |
|
" #@markdown ---\n", |
|
"\n", |
|
"#@markdown ####**Advanced Settings:**\n", |
|
"#@markdown *There are a few extra advanced settings available if you double click this cell.*\n", |
|
"\n", |
|
"#@markdown *Perlin init will replace your init, so uncheck if using one.*\n", |
|
"\n", |
|
"perlin_init = False #@param{type: 'boolean'}\n", |
|
"perlin_mode = 'mixed' #@param ['mixed', 'color', 'gray']\n", |
|
"set_seed = 'random_seed' #@param{type: 'string'}\n", |
|
"eta = 0.8#@param{type: 'number'}\n", |
|
"clamp_grad = True #@param{type: 'boolean'}\n", |
|
"clamp_max = 0.05 #@param{type: 'number'}\n", |
|
"\n", |
|
"\n", |
|
"### EXTRA ADVANCED SETTINGS:\n", |
|
"randomize_class = True\n", |
|
"clip_denoised = False\n", |
|
"fuzzy_prompt = False\n", |
|
"rand_mag = 0.05\n", |
|
"\n", |
|
"\n", |
|
" #@markdown ---\n", |
|
"\n", |
|
"#@markdown ####**Cutn Scheduling:**\n", |
|
"#@markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400 /1000 steps, then 20 for the last 600/1000\n", |
|
"\n", |
|
"#@markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure, innercuts are your standard cutn.\n", |
|
"\n", |
|
"cut_overview = \"[12]*400+[4]*600\" #@param {type: 'string'} \n", |
|
"cut_innercut =\"[4]*400+[12]*600\"#@param {type: 'string'} \n", |
|
"cut_ic_pow = 1#@param {type: 'number'} \n", |
|
"cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}\n" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "markdown", |
|
"metadata": {}, |
|
"source": [ |
|
"### Prompts\n", |
|
"`animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one." |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"text_prompts = {\n", |
|
" 0: [\"A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.\", \"yellow color scheme\"],\n", |
|
" 100: [\"This set of prompts start at frame 100\",\"This prompt has weight five:5\"],\n", |
|
"}\n", |
|
"\n", |
|
"image_prompts = {\n", |
|
" # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n", |
|
"}\n" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "markdown", |
|
"metadata": {}, |
|
"source": [ |
|
"# 4. Diffuse!" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"#@title Do the Run!\n", |
|
"#@markdown `n_batches` ignored with animation modes.\n", |
|
"display_rate = 50 #@param{type: 'number'}\n", |
|
"n_batches = 50 #@param{type: 'number'}\n", |
|
"\n", |
|
"#Update Model Settings\n", |
|
"timestep_respacing = f'ddim{steps}'\n", |
|
"diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n", |
|
"model_config.update({\n", |
|
" 'timestep_respacing': timestep_respacing,\n", |
|
" 'diffusion_steps': diffusion_steps,\n", |
|
"})\n", |
|
"\n", |
|
"batch_size = 1 \n", |
|
"\n", |
|
"def move_files(start_num, end_num, old_folder, new_folder):\n", |
|
" for i in range(start_num, end_num):\n", |
|
" old_file = old_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n", |
|
" new_file = new_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n", |
|
" os.rename(old_file, new_file)\n", |
|
"\n", |
|
"#@markdown ---\n", |
|
"\n", |
|
"\n", |
|
"resume_run = False #@param{type: 'boolean'}\n", |
|
"run_to_resume = 'latest' #@param{type: 'string'}\n", |
|
"resume_from_frame = 'latest' #@param{type: 'string'}\n", |
|
"retain_overwritten_frames = False #@param{type: 'boolean'}\n", |
|
"if retain_overwritten_frames is True:\n", |
|
" retainFolder = f'{batchFolder}/retained'\n", |
|
" createPath(retainFolder)\n", |
|
"\n", |
|
"\n", |
|
"skip_step_ratio = int(frames_skip_steps.rstrip(\"%\")) / 100\n", |
|
"calc_frames_skip_steps = math.floor(steps * skip_step_ratio)\n", |
|
"\n", |
|
"\n", |
|
"if steps <= calc_frames_skip_steps:\n", |
|
" sys.exit(\"ERROR: You can't skip more steps than your total steps\")\n", |
|
"\n", |
|
"if resume_run:\n", |
|
" if run_to_resume == 'latest':\n", |
|
" try:\n", |
|
" batchNum\n", |
|
" except:\n", |
|
" batchNum = len(glob(f\"{batchFolder}/{batch_name}(*)_settings.txt\"))-1\n", |
|
" else:\n", |
|
" batchNum = int(run_to_resume)\n", |
|
" if resume_from_frame == 'latest':\n", |
|
" start_frame = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n", |
|
" else:\n", |
|
" start_frame = int(resume_from_frame)+1\n", |
|
" if retain_overwritten_frames is True:\n", |
|
" existing_frames = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n", |
|
" frames_to_save = existing_frames - start_frame\n", |
|
" print(f'Moving {frames_to_save} frames to the Retained folder')\n", |
|
" move_files(start_frame, existing_frames, batchFolder, retainFolder)\n", |
|
"else:\n", |
|
" start_frame = 0\n", |
|
" batchNum = len(glob(batchFolder+\"/*.txt\"))\n", |
|
" while path.isfile(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\") is True or path.isfile(f\"{batchFolder}/{batch_name}-{batchNum}_settings.txt\") is True:\n", |
|
" batchNum += 1\n", |
|
"\n", |
|
"print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')\n", |
|
"\n", |
|
"if set_seed == 'random_seed':\n", |
|
" random.seed()\n", |
|
" seed = random.randint(0, 2**32)\n", |
|
" # print(f'Using seed: {seed}')\n", |
|
"else:\n", |
|
" seed = int(set_seed)\n", |
|
"\n", |
|
"args = {\n", |
|
" 'batchNum': batchNum,\n", |
|
" 'prompts_series':split_prompts(text_prompts) if text_prompts else None,\n", |
|
" 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,\n", |
|
" 'seed': seed,\n", |
|
" 'display_rate':display_rate,\n", |
|
" 'n_batches':n_batches if animation_mode == 'None' else 1,\n", |
|
" 'batch_size':batch_size,\n", |
|
" 'batch_name': batch_name,\n", |
|
" 'steps': steps,\n", |
|
" 'sampling_mode': sampling_mode,\n", |
|
" 'width_height': width_height,\n", |
|
" 'clip_guidance_scale': clip_guidance_scale,\n", |
|
" 'tv_scale': tv_scale,\n", |
|
" 'range_scale': range_scale,\n", |
|
" 'sat_scale': sat_scale,\n", |
|
" 'cutn_batches': cutn_batches,\n", |
|
" 'init_image': init_image,\n", |
|
" 'init_scale': init_scale,\n", |
|
" 'skip_steps': skip_steps,\n", |
|
" 'sharpen_preset': sharpen_preset,\n", |
|
" 'keep_unsharp': keep_unsharp,\n", |
|
" 'side_x': side_x,\n", |
|
" 'side_y': side_y,\n", |
|
" 'timestep_respacing': timestep_respacing,\n", |
|
" 'diffusion_steps': diffusion_steps,\n", |
|
" 'animation_mode': animation_mode,\n", |
|
" 'video_init_path': video_init_path,\n", |
|
" 'extract_nth_frame': extract_nth_frame,\n", |
|
" 'key_frames': key_frames,\n", |
|
" 'max_frames': max_frames if animation_mode != \"None\" else 1,\n", |
|
" 'interp_spline': interp_spline,\n", |
|
" 'start_frame': start_frame,\n", |
|
" 'angle': angle,\n", |
|
" 'zoom': zoom,\n", |
|
" 'translation_x': translation_x,\n", |
|
" 'translation_y': translation_y,\n", |
|
" 'translation_z': translation_z,\n", |
|
" 'rotation_3d_x': rotation_3d_x,\n", |
|
" 'rotation_3d_y': rotation_3d_y,\n", |
|
" 'rotation_3d_z': rotation_3d_z,\n", |
|
" 'midas_depth_model': midas_depth_model,\n", |
|
" 'midas_weight': midas_weight,\n", |
|
" 'near_plane': near_plane,\n", |
|
" 'far_plane': far_plane,\n", |
|
" 'fov': fov,\n", |
|
" 'padding_mode': padding_mode,\n", |
|
" 'sampling_mode': sampling_mode,\n", |
|
" 'angle_series':angle_series,\n", |
|
" 'zoom_series':zoom_series,\n", |
|
" 'translation_x_series':translation_x_series,\n", |
|
" 'translation_y_series':translation_y_series,\n", |
|
" 'translation_z_series':translation_z_series,\n", |
|
" 'rotation_3d_x_series':rotation_3d_x_series,\n", |
|
" 'rotation_3d_y_series':rotation_3d_y_series,\n", |
|
" 'rotation_3d_z_series':rotation_3d_z_series,\n", |
|
" 'frames_scale': frames_scale,\n", |
|
" 'calc_frames_skip_steps': calc_frames_skip_steps,\n", |
|
" 'skip_step_ratio': skip_step_ratio,\n", |
|
" 'calc_frames_skip_steps': calc_frames_skip_steps,\n", |
|
" 'text_prompts': text_prompts,\n", |
|
" 'image_prompts': image_prompts,\n", |
|
" 'cut_overview': eval(cut_overview),\n", |
|
" 'cut_innercut': eval(cut_innercut),\n", |
|
" 'cut_ic_pow': cut_ic_pow,\n", |
|
" 'cut_icgray_p': eval(cut_icgray_p),\n", |
|
" 'intermediate_saves': intermediate_saves,\n", |
|
" 'intermediates_in_subfolder': intermediates_in_subfolder,\n", |
|
" 'steps_per_checkpoint': steps_per_checkpoint,\n", |
|
" 'perlin_init': perlin_init,\n", |
|
" 'perlin_mode': perlin_mode,\n", |
|
" 'set_seed': set_seed,\n", |
|
" 'eta': eta,\n", |
|
" 'clamp_grad': clamp_grad,\n", |
|
" 'clamp_max': clamp_max,\n", |
|
" 'skip_augs': skip_augs,\n", |
|
" 'randomize_class': randomize_class,\n", |
|
" 'clip_denoised': clip_denoised,\n", |
|
" 'fuzzy_prompt': fuzzy_prompt,\n", |
|
" 'rand_mag': rand_mag,\n", |
|
"}\n", |
|
"\n", |
|
"args = SimpleNamespace(**args)\n", |
|
"\n", |
|
"print('Prepping model...')\n", |
|
"model, diffusion = create_model_and_diffusion(**model_config)\n", |
|
"model.load_state_dict(torch.load(f'{model_path}/{diffusion_model}.pt', map_location='cpu'))\n", |
|
"model.requires_grad_(False).eval().to(device)\n", |
|
"for name, param in model.named_parameters():\n", |
|
" if 'qkv' in name or 'norm' in name or 'proj' in name:\n", |
|
" param.requires_grad_()\n", |
|
"if model_config['use_fp16']:\n", |
|
" model.convert_to_fp16()\n", |
|
"\n", |
|
"gc.collect()\n", |
|
"torch.cuda.empty_cache()\n", |
|
"try:\n", |
|
" do_run()\n", |
|
"except KeyboardInterrupt:\n", |
|
" pass\n", |
|
"finally:\n", |
|
" print('Seed used:', seed)\n", |
|
" gc.collect()\n", |
|
" torch.cuda.empty_cache()\n" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
}, |
|
{ |
|
"cell_type": "markdown", |
|
"metadata": {}, |
|
"source": [ |
|
"# 5. Create the video" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"metadata": {}, |
|
"source": [ |
|
"# @title ### **Create video**\n", |
|
"#@markdown Video file will save in the same folder as your images.\n", |
|
"\n", |
|
"skip_video_for_run_all = True #@param {type: 'boolean'}\n", |
|
"\n", |
|
"if skip_video_for_run_all == True:\n", |
|
" print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n", |
|
"\n", |
|
"else:\n", |
|
" # import subprocess in case this cell is run without the above cells\n", |
|
" import subprocess\n", |
|
" from base64 import b64encode\n", |
|
"\n", |
|
" latest_run = batchNum\n", |
|
"\n", |
|
" folder = batch_name #@param\n", |
|
" run = latest_run #@param\n", |
|
" final_frame = 'final_frame'\n", |
|
"\n", |
|
"\n", |
|
" init_frame = 1#@param {type:\"number\"} This is the frame where the video will start\n", |
|
" last_frame = final_frame#@param {type:\"number\"} You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.\n", |
|
" fps = 12#@param {type:\"number\"}\n", |
|
" # view_video_in_cell = True #@param {type: 'boolean'}\n", |
|
"\n", |
|
" frames = []\n", |
|
" # tqdm.write('Generating video...')\n", |
|
"\n", |
|
" if last_frame == 'final_frame':\n", |
|
" last_frame = len(glob(batchFolder+f\"/{folder}({run})_*.png\"))\n", |
|
" print(f'Total frames: {last_frame}')\n", |
|
"\n", |
|
" image_path = f\"{outDirPath}/{folder}/{folder}({run})_%04d.png\"\n", |
|
" filepath = f\"{outDirPath}/{folder}/{folder}({run}).mp4\"\n", |
|
"\n", |
|
"\n", |
|
" cmd = [\n", |
|
" 'ffmpeg',\n", |
|
" '-y',\n", |
|
" '-vcodec',\n", |
|
" 'png',\n", |
|
" '-r',\n", |
|
" str(fps),\n", |
|
" '-start_number',\n", |
|
" str(init_frame),\n", |
|
" '-i',\n", |
|
" image_path,\n", |
|
" '-frames:v',\n", |
|
" str(last_frame+1),\n", |
|
" '-c:v',\n", |
|
" 'libx264',\n", |
|
" '-vf',\n", |
|
" f'fps={fps}',\n", |
|
" '-pix_fmt',\n", |
|
" 'yuv420p',\n", |
|
" '-crf',\n", |
|
" '17',\n", |
|
" '-preset',\n", |
|
" 'veryslow',\n", |
|
" filepath\n", |
|
" ]\n", |
|
"\n", |
|
" process = subprocess.Popen(cmd, cwd=f'{batchFolder}', stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", |
|
" stdout, stderr = process.communicate()\n", |
|
" if process.returncode != 0:\n", |
|
" print(stderr)\n", |
|
" raise RuntimeError(stderr)\n", |
|
" else:\n", |
|
" print(\"The video is ready and saved to the images folder\")\n", |
|
"\n", |
|
" # if view_video_in_cell:\n", |
|
" # mp4 = open(filepath,'rb').read()\n", |
|
" # data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", |
|
" # display.HTML(f'<video width=400 controls><source src=\"{data_url}\" type=\"video/mp4\"></video>')" |
|
], |
|
"outputs": [], |
|
"execution_count": null |
|
} |
|
], |
|
"metadata": { |
|
"anaconda-cloud": {}, |
|
"accelerator": "GPU", |
|
"colab": { |
|
"collapsed_sections": [ |
|
"1YwMUyt9LHG1", |
|
"XTu6AjLyFQUq", |
|
"_9Eg9Kf5FlfK", |
|
"CnkTNXJAPzL2", |
|
"u1VHzHvNx5fd" |
|
], |
|
"machine_shape": "hm", |
|
"name": "Disco Diffusion v5 [w/ 3D animation]", |
|
"private_outputs": true, |
|
"provenance": [], |
|
"include_colab_link": true |
|
}, |
|
"kernelspec": { |
|
"display_name": "Python 3", |
|
"language": "python", |
|
"name": "python3" |
|
}, |
|
"language_info": { |
|
"codemirror_mode": { |
|
"name": "ipython", |
|
"version": 3 |
|
}, |
|
"file_extension": ".py", |
|
"mimetype": "text/x-python", |
|
"name": "python", |
|
"nbconvert_exporter": "python", |
|
"pygments_lexer": "ipython3", |
|
"version": "3.6.1" |
|
} |
|
}, |
|
"nbformat": 4, |
|
"nbformat_minor": 4 |
|
} |