VR modifications for stereo 180 frames and spherical projection in utils

pull/41/head
Tom Mason 3 years ago
parent b99492d4e7
commit e4754e6a42
  1. 68
      Disco_Diffusion.ipynb
  2. 20
      disco_xform_utils.py

@ -286,13 +286,13 @@
"source": [
"**Diffusion settings (Defaults are heavily outdated)**\n",
"---\n",
"\n",
"Disco Diffusion is complex, and continually evolving with new features. The most current documentation on on Disco Diffusion settings can be found in the unofficial guidebook:\n",
"\n",
"\n",
"Disco Diffusion is complex, and continually evolving with new features. The most current documentation on on Disco Diffusion settings can be found in the unofficial guidebook:\n",
"\n",
"[Zippy's Disco Diffusion Cheatsheet](https://docs.google.com/document/d/1l8s7uS2dGqjztYSjPpzlmXLjl5PM3IGkRWI3IiCuK7g/edit)\n",
"\n",
"We also encourage users to join the [Disco Diffusion User Discord](https://discord.gg/XGZrFFCRfN) to learn from the active user community.",
"\n",
"We also encourage users to join the [Disco Diffusion User Discord](https://discord.gg/XGZrFFCRfN) to learn from the active user community.\n",
"\n",
"This section below is outdated as of v2\n",
"\n",
"Setting | Description | Default\n",
@ -1423,11 +1423,30 @@
" cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
" else:\n",
" image.save(f'{batchFolder}/{filename}')\n",
" \n",
" if vr_mode:\n",
" generate_eye_views(trans_scale,batchFolder,filename,frame_num,midas_model, midas_transform)\n",
" \n",
" # if frame_num != args.max_frames-1:\n",
" # display.clear_output()\n",
" \n",
" plt.plot(np.array(loss_values), 'r')\n",
"\n",
"def generate_eye_views(trans_scale,batchFolder,filename,frame_num,midas_model, midas_transform):\n",
" for i in range(2):\n",
" theta = vr_eye_angle * (math.pi/180)\n",
" ray_origin = math.cos(theta) * vr_ipd / 2 * (-1.0 if i==0 else 1.0)\n",
" ray_rotation = (theta if i==0 else -theta)\n",
" translate_xyz = [-(ray_origin)*trans_scale, 0,0]\n",
" rotate_xyz = [0, (ray_rotation), 0]\n",
" rot_mat = p3dT.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), \"XYZ\").unsqueeze(0)\n",
" transformed_image = dxf.transform_image_3d(f'{batchFolder}/{filename}', midas_model, midas_transform, DEVICE,\n",
" rot_mat, translate_xyz, args.near_plane, args.far_plane,\n",
" args.fov, padding_mode=args.padding_mode,\n",
" sampling_mode=args.sampling_mode, midas_weight=args.midas_weight,spherical=True)\n",
" eye_file_path = batchFolder+f\"/frame_{frame_num-1:04}\" + ('_l' if i==0 else '_r')+'.png'\n",
" transformed_image.save(eye_file_path)\n",
"\n",
"def save_settings():\n",
" setting_list = {\n",
" 'text_prompts': text_prompts,\n",
@ -1693,9 +1712,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ModelSettings"
},
"outputs": [],
"source": [
"#@markdown ####**Models Settings:**\n",
"diffusion_model = \"512x512_diffusion_uncond_finetune_008100\" #@param [\"256x256_diffusion_uncond\", \"512x512_diffusion_uncond_finetune_008100\"]\n",
@ -2008,6 +2029,31 @@
"frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}\n",
"\n",
"\n",
"#======= VR MODE\n",
"#@markdown ---\n",
"#@markdown ####**VR Mode (3D anim only):**\n",
"#@markdown Enables stereo rendering of left/right eye views (supporting Turbo) which use a different (fish-eye) camera projection matrix. \n",
"#@markdown Note the images you're prompting will work better if they have some inherent wide-angle aspect\n",
"#@markdown The generated images will need to be combined into left/right videos. These can then be stitched into the VR180 format.\n",
"#@markdown Google made the VR180 Creator tool but subsequently stopped supporting it. It's available for download in a few places including https://www.patrickgrunwald.de/vr180-creator-download\n",
"#@markdown The tool is not only good for stitching (videos and photos) but also for adding the correct metadata into existing videos, which is needed for services like YouTube to identify the format correctly.\n",
"#@markdown Watching YouTube VR videos isn't necessarily the easiest depending on your headset. For instance Oculus have a dedicated media studio and store which makes the files easier to access on a Quest https://creator.oculus.com/manage/mediastudio/\n",
"#@markdown The command to get ffmpeg to concat your frames for each eye is in the form: ffmpeg -framerate 15 -i frame_%4d_l.png l.mp4 (repeat for r)",
"\n",
"vr_mode = False\n",
"vr_eye_angle:0.5, #@param{type: 'number'}\n",
"#@markdown eye_angle is the y-axis rotation of the eyes towards the center\n",
"vr_ipd:5.0,\n",
"#@markdown interpupillary distance (between the eyes)\n",
" \n",
"#insist turbo be used only w 3d anim.\n",
"if vr_mode and animation_mode != '3D':\n",
" print('=====')\n",
" print('VR mode only available with 3D animations. Disabling VR.')\n",
" print('=====')\n",
" turbo_mode = False\n",
"\n",
"\n",
"def parse_key_frames(string, prompt_parser=None):\n",
" \"\"\"Given a string representing frame numbers paired with parameter values at that frame,\n",
" return a dictionary with the frame numbers as keys and the parameter values as the values.\n",
@ -2341,9 +2387,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Prompts"
},
"outputs": [],
"source": [
"text_prompts = {\n",
" 0: [\"A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.\", \"yellow color scheme\"],\n",
@ -2368,9 +2416,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DoTheRun"
},
"outputs": [],
"source": [
"#@title Do the Run!\n",
"#@markdown `n_batches` ignored with animation modes.\n",
@ -2566,9 +2616,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CreateVid"
},
"outputs": [],
"source": [
"# @title ### **Create video**\n",
"#@markdown Video file will save in the same folder as your images.\n",
@ -2651,8 +2703,8 @@
}
],
"metadata": {
"anaconda-cloud": {},
"accelerator": "GPU",
"anaconda-cloud": {},
"colab": {
"collapsed_sections": [
"CreditsChTop",
@ -2666,11 +2718,11 @@
"AnimSetTop",
"ExtraSetTop"
],
"include_colab_link": true,
"machine_shape": "hm",
"name": "Disco Diffusion v5.1 [w/ Turbo]",
"private_outputs": true,
"provenance": [],
"include_colab_link": true
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",

@ -15,7 +15,7 @@ MAX_ADABINS_AREA = 500000
MIN_ADABINS_AREA = 448*448
@torch.no_grad()
def transform_image_3d(img_filepath, midas_model, midas_transform, device, rot_mat=torch.eye(3).unsqueeze(0), translate=(0.,0.,-0.04), near=2000, far=20000, fov_deg=60, padding_mode='border', sampling_mode='bicubic', midas_weight = 0.3):
def transform_image_3d(img_filepath, midas_model, midas_transform, device, rot_mat=torch.eye(3).unsqueeze(0), translate=(0.,0.,-0.04), near=2000, far=20000, fov_deg=60, padding_mode='border', sampling_mode='bicubic', midas_weight = 0.3,spherical=False):
img_pil = Image.open(open(img_filepath, 'rb')).convert('RGB')
w, h = img_pil.size
image_tensor = torchvision.transforms.functional.to_tensor(img_pil).to(device)
@ -107,9 +107,25 @@ def transform_image_3d(img_filepath, midas_model, midas_transform, device, rot_m
# coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs.
coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False)
offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0)
new_image = torch.nn.functional.grid_sample(image_tensor.add(1/512 - 0.0001).unsqueeze(0), offset_coords_2d, mode=sampling_mode, padding_mode=padding_mode, align_corners=False)
if spherical:
spherical_grid = get_spherical_projection(h, w, torch.tensor([0,0], device=device), -0.4,device=device)#align_corners=False
stage_image = torch.nn.functional.grid_sample(image_tensor.add(1/512 - 0.0001).unsqueeze(0), offset_coords_2d, mode=sampling_mode, padding_mode=padding_mode, align_corners=True)
new_image = torch.nn.functional.grid_sample(stage_image, spherical_grid,align_corners=True) #, mode=sampling_mode, padding_mode=padding_mode, align_corners=False)
else:
new_image = torch.nn.functional.grid_sample(image_tensor.add(1/512 - 0.0001).unsqueeze(0), offset_coords_2d, mode=sampling_mode, padding_mode=padding_mode, align_corners=False)
img_pil = torchvision.transforms.ToPILImage()(new_image.squeeze().clamp(0,1.))
torch.cuda.empty_cache()
return img_pil
def get_spherical_projection(H, W, center, magnitude,device):
xx, yy = torch.linspace(-1, 1, W,dtype=torch.float32,device=device), torch.linspace(-1, 1, H,dtype=torch.float32,device=device)
gridy, gridx = torch.meshgrid(yy, xx)
grid = torch.stack([gridx, gridy], dim=-1)
d = center - grid
d_sum = torch.sqrt((d**2).sum(axis=-1))
grid += d * d_sum.unsqueeze(-1) * magnitude
return grid.unsqueeze(0)
Loading…
Cancel
Save