From d131bec014f1793f3b5e764644aa05d13dbaafc1 Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 28 Feb 2023 12:35:28 -0800 Subject: [PATCH 01/14] add checkpoint_every setting --- Stable_Diffusion_KLMC2_Animation.ipynb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 41015e2..c19417f 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -580,6 +580,7 @@ " resume_from=-1,\n", " img_init_steps=None,\n", " stuff_to_plot=None,\n", + " checkpoint_every=10,\n", "):\n", "\n", " if stuff_to_plot is None:\n", @@ -689,7 +690,7 @@ " extra_args,\n", " )\n", "\n", - " save_checkpoint = (i % 10) == 0\n", + " save_checkpoint = (i % checkpoint_every) == 0\n", " if save_checkpoint:\n", " ex.submit(write_klmc2_state, v=v, x=x, i=i)\n", " logger.debug(settings[i])\n", @@ -1197,6 +1198,7 @@ "# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.\n", "hvp_method = 'fake' # @param [\"forward-functorch\", \"reverse\", \"fake\", \"zero\"]\n", "\n", + "checkpoint_every = 10 # @param {type:\"number\"}\n", "\n", "###########################\n", "\n", @@ -1404,6 +1406,7 @@ " resume_from=resume_from,\n", " img_init_steps=img_init_steps,\n", " stuff_to_plot=stuff_to_plot,\n", + " checkpoint_every=checkpoint_every,\n", ")\n" ] }, From 1025ff7376ce86eed6276110203b19ac8c6a24e0 Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 28 Feb 2023 12:46:33 -0800 Subject: [PATCH 02/14] settings resume --- Stable_Diffusion_KLMC2_Animation.ipynb | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index c19417f..d6d763e 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -631,6 +631,8 @@ " state = 
read_klmc2_state(latest_frame=resume_from)\n", " if state:\n", " x, v, i_resume = state['x'], state['v'], state['i']\n", + " # to do: resumption of settings\n", + " settings_i = state['settings_i']\n", " \n", " # to do: use multicond for init image\n", " # we want this test after resumption if resuming\n", @@ -653,14 +655,23 @@ " # fast-forward loop to resumption index\n", " if resume and i < i_resume:\n", " continue\n", - "\n", - " h = settings[i]['h']\n", - " gamma = settings[i]['gamma']\n", - " alpha = settings[i]['alpha']\n", - " tau = settings[i]['tau']\n", - " g = settings[i]['g']\n", - " sigma = settings[i]['sigma']\n", - " steps = settings[i]['steps']\n", + " if resume and (i == i_resume):\n", + " # should these values be written into settings[i]?\n", + " h = settings_i['h']\n", + " gamma = settings_i['gamma']\n", + " alpha = settings_i['alpha']\n", + " tau = settings_i['tau']\n", + " g = settings_i['g']\n", + " sigma = settings_i['sigma']\n", + " steps = settings_i['steps']\n", + " else:\n", + " h = settings[i]['h']\n", + " gamma = settings[i]['gamma']\n", + " alpha = settings[i]['alpha']\n", + " tau = settings[i]['tau']\n", + " g = settings[i]['g']\n", + " sigma = settings[i]['sigma']\n", + " steps = settings[i]['steps']\n", "\n", " h = torch.tensor(h, device=x.device)\n", " gamma = torch.tensor(gamma, device=x.device)\n", @@ -692,7 +703,8 @@ "\n", " save_checkpoint = (i % checkpoint_every) == 0\n", " if save_checkpoint:\n", - " ex.submit(write_klmc2_state, v=v, x=x, i=i)\n", + " settings_i = settings[i]\n", + " ex.submit(write_klmc2_state, v=v, x=x, i=i, settings_i=settings_i)\n", " logger.debug(settings[i])\n", "\n", "\n", From fe29fb5a087d0d5990cc49939fadbc238e1e0fac Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 28 Feb 2023 13:05:15 -0800 Subject: [PATCH 03/14] resume by writing param vals into settings obj --- Stable_Diffusion_KLMC2_Animation.ipynb | 41 +++++++++++++++----------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git 
a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index d6d763e..cf78880 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -633,6 +633,13 @@ " x, v, i_resume = state['x'], state['v'], state['i']\n", " # to do: resumption of settings\n", " settings_i = state['settings_i']\n", + " settings[i]['h'] = settings_i['h']\n", + " settings[i]['gamma'] = settings_i['gamma']\n", + " settings[i]['alpha'] = settings_i['alpha']\n", + " settings[i]['tau'] = settings_i['tau']\n", + " settings[i]['g'] = settings_i['g']\n", + " settings[i]['sigma'] = settings_i['sigma']\n", + " settings[i]['steps'] = settings_i['steps']\n", " \n", " # to do: use multicond for init image\n", " # we want this test after resumption if resuming\n", @@ -655,23 +662,23 @@ " # fast-forward loop to resumption index\n", " if resume and i < i_resume:\n", " continue\n", - " if resume and (i == i_resume):\n", - " # should these values be written into settings[i]?\n", - " h = settings_i['h']\n", - " gamma = settings_i['gamma']\n", - " alpha = settings_i['alpha']\n", - " tau = settings_i['tau']\n", - " g = settings_i['g']\n", - " sigma = settings_i['sigma']\n", - " steps = settings_i['steps']\n", - " else:\n", - " h = settings[i]['h']\n", - " gamma = settings[i]['gamma']\n", - " alpha = settings[i]['alpha']\n", - " tau = settings[i]['tau']\n", - " g = settings[i]['g']\n", - " sigma = settings[i]['sigma']\n", - " steps = settings[i]['steps']\n", + " # if resume and (i == i_resume):\n", + " # # should these values be written into settings[i]?\n", + " # h = settings_i['h']\n", + " # gamma = settings_i['gamma']\n", + " # alpha = settings_i['alpha']\n", + " # tau = settings_i['tau']\n", + " # g = settings_i['g']\n", + " # sigma = settings_i['sigma']\n", + " # steps = settings_i['steps']\n", + " # else:\n", + " h = settings[i]['h']\n", + " gamma = settings[i]['gamma']\n", + " alpha = settings[i]['alpha']\n", + " tau = 
settings[i]['tau']\n", + " g = settings[i]['g']\n", + " sigma = settings[i]['sigma']\n", + " steps = settings[i]['steps']\n", "\n", " h = torch.tensor(h, device=x.device)\n", " gamma = torch.tensor(gamma, device=x.device)\n", From 37fe51b5f176d7172186c2d9426215e2cb10cc1a Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 28 Feb 2023 13:30:31 -0800 Subject: [PATCH 04/14] typehints --- Stable_Diffusion_KLMC2_Animation.ipynb | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index cf78880..10dd82a 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -570,17 +570,17 @@ "def sample_mcmc_klmc2(\n", " sd_model, \n", " init_image,\n", - " height,\n", - " width,\n", - " n, \n", - " hvp_method='reverse', \n", - " prompts=None,\n", - " settings=None,\n", - " resume = False,\n", - " resume_from=-1,\n", - " img_init_steps=None,\n", - " stuff_to_plot=None,\n", - " checkpoint_every=10,\n", + " height:int,\n", + " width:int,\n", + " n:int, \n", + " hvp_method:str='reverse', \n", + " prompts:list=None,\n", + " settings:ParameterGroup=None,\n", + " resume:bool = False,\n", + " resume_from:int=-1,\n", + " img_init_steps:int=None,\n", + " stuff_to_plot:list=None,\n", + " checkpoint_every:int=10,\n", "):\n", "\n", " if stuff_to_plot is None:\n", From 826f5990924a4a99c68871a0ac27925b62509082 Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 1 Mar 2023 21:36:25 -0800 Subject: [PATCH 05/14] basic settings save/load --- Stable_Diffusion_KLMC2_Animation.ipynb | 63 ++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 10dd82a..faa3192 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -85,7 +85,8 @@ "execution_count": null, "metadata": { "cellView": 
"form", - "id": "Ty3IOeXbLzvc" + "id": "Ty3IOeXbLzvc", + "tags": [] }, "outputs": [], "source": [ @@ -122,7 +123,8 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "kelHR9VM1-hg" + "id": "kelHR9VM1-hg", + "tags": [] }, "outputs": [], "source": [ @@ -165,7 +167,8 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "fJZtXShcPXx5" + "id": "fJZtXShcPXx5", + "tags": [] }, "outputs": [], "source": [ @@ -908,9 +911,6 @@ "metadata": { "cellView": "form", "id": "yt3d1hww17ST", - "jupyter": { - "source_hidden": true - }, "tags": [] }, "outputs": [], @@ -1318,10 +1318,10 @@ "curved_settings = ParameterGroup({\n", " 'g':SmoothCurve(g),\n", " 'sigma':SmoothCurve(sigma),\n", - " 'h':SmoothCurve(h),\n", + " #'h':SmoothCurve(h),\n", " \n", " # more concise notation for flowers demo:\n", - " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", + " 'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3, 70:0.1, 90:0.1}, loop=True),\n", "\n", " 'gamma':SmoothCurve(gamma),\n", @@ -1345,6 +1345,53 @@ " plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "dir(keyframed.utils)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import keyframed.serialization\n", + "txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", + "#print(txt)\n", + "\n", + "with open(outdir / 'settings.yaml', 'w') as f:\n", + " f.write(txt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# load settings from disk\n", + "\n", + "load_settings_from_disk = True # @param {type:'boolean'}\n", + "\n", + "if load_settings_from_disk:\n", + " with open(outdir / 'settings.yaml', 'r') as f:\n", + " curved_settings = keyframed.serialization.from_yaml(f.read())\n", + 
"\n", + "curved_settings.to_dict(simplify=True)['parameters']\n", + "#curved_settings.plot()" + ] + }, { "cell_type": "code", "execution_count": null, From c4486bcec288b85561bcd1291518f175a32653df Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 1 Mar 2023 22:00:42 -0800 Subject: [PATCH 06/14] in-comment documentation --- Stable_Diffusion_KLMC2_Animation.ipynb | 33 +++++++++++++++++--------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index faa3192..3825404 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1345,17 +1345,6 @@ " plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "dir(keyframed.utils)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -1366,8 +1355,30 @@ "source": [ "import keyframed.serialization\n", "txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", + "\n", "#print(txt)\n", "\n", + "# sigma: 1.25\n", + "#\n", + "# becomes:\n", + "#\n", + "# sigma:\n", + "# curve:\n", + "# - - 0\n", + "# - 1.25\n", + "# - eased_lerp\n", + "#\n", + "# :\n", + "# curve:\n", + "# - - \n", + "# - \n", + "# - \n", + "# - \n", + "# - - \n", + "# - \n", + "# - - \n", + "# - \n", + "\n", "with open(outdir / 'settings.yaml', 'w') as f:\n", " f.write(txt)" ] From e9bab27919c6891f7e2bc4c40d819d0d8db1bade Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 1 Mar 2023 22:27:16 -0800 Subject: [PATCH 07/14] Created using Colaboratory --- Stable_Diffusion_KLMC2_Animation.ipynb | 3235 ++++++++++++------------ 1 file changed, 1622 insertions(+), 1613 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 3825404..8bec79d 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1,1615 +1,1624 @@ { - "cells": [ - { - 
"cell_type": "markdown", - "metadata": { - "id": "15BNHICpOOXg" - }, - "source": [ - "# Stable Diffusion KLMC2 Animation\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "\n", - "Notebook by [Katherine Crowson](https://twitter.com/RiversHaveWings), modified by [David Marx](https://twitter.com/DigThatData).\n", - "\n", - "Sponsored by [StabilityAI](https://twitter.com/stabilityai)\n", - "\n", - "Generate animations with [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) 1.4, using the [KLMC2 discretization of underdamped Langevin dynamics](https://arxiv.org/abs/1807.09382). The notebook is largely inspired by [Ajay Jain](https://twitter.com/ajayj_) and [Ben Poole](https://twitter.com/poolio)'s paper [Journey to the BAOAB-limit](https://www.ajayjain.net/journey)—thank you so much for it!\n", - "\n", - "---\n", - "\n", - "## Modifications Provenance\n", - "\n", - "Original notebook URL - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m8ovBpO2QilE2o4O-p2PONSwqGn4_x2G)\n", - "\n", - "Features and QOL Modifications by [David Marx](https://twitter.com/DigThatData) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmarx/notebooks/blob/main/Stable_Diffusion_KLMC2_Animation.ipynb)\n", - "\n", - "Last updated date (effectively, notebook version): 2022-02-20\n", - "\n", - "* Keyframed prompts and settings\n", - "* Multiprompt conditioning w independent prompt schedules\n", - "* Set seed for deterministic output\n", - "* Mount Google Drive\n", - "* Faster Setup\n", - "* Init image\n", - "* Alt-checkpoint loading consistent w/deforum\n", - "* Set output filename\n", - "* Fancy GPU info\n", - "* Video embed optional\n", - "* ~~Cheaper default runtime~~ torn about this\n", - "* Local setup\n", - "* New VAE option\n", - "* Smooth interpolation for settings curves\n", - "* Settings curves specified via simple DSL\n", - "* Exposed `refinement_steps` parameter\n", - "* Custom output resolution\n", - "* Optional video upscale\n", - "* Optional resume, user can specify 
resumption frame (auto-checkpoints every 10 frames)\n", - "* Optional archival\n", - "* Assorted refactoring\n", - "* Debugging plots and animations\n", - "\n", - "## Local Setup\n", - "\n", - "Download the repo containing this notebook and supplementary setup files.\n", - "\n", - "```\n", - "git clone https://github.com/dmarx/notebooks\n", - "cd notebooks\n", - "```\n", - "\n", - "Strongly recommend setting up and activating a virtual environment first. Here's one option that is built into python, windows users in particular might want to consider using anaconda as an alternative.\n", - "\n", - "```bash\n", - "python3 -m venv _venv\n", - "source _venv/bin/activate\n", - "pip install jupyter\n", - "```\n", - "\n", - "With this venv created, in the future you only need to run `source _venv/bin/activate` to activate it.\n", - "\n", - "You can now start a local jupyter instance from the terminal in which the virtual environment is activated by running the `jupyter` command, or alternatively select the new virtualenv as the python environment in your IDE of choice. When you run the notebook's setup cells, it should detect that local setup needs to be performed and modify its setup procedure appropriately.\n", - "\n", - "A common source of errors is user confusion between the python environment running the notebook and an intended virtual environment into which setup has already been performed. To validate that you are using the python environment you think you are, run the command `which python` (this locates the executable associated with the `python` command) both inside the notebook and in a terminal in which your venv is activated: the results should be identical.\n", - "\n", - "## Contact\n", - "\n", - "Report bugs or feature ideas here: https://github.com/dmarx/notebooks/issues" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "15BNHICpOOXg" + }, + "source": [ + "# Stable Diffusion KLMC2 Animation\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "\n", + "Notebook by [Katherine Crowson](https://twitter.com/RiversHaveWings), modified by [David Marx](https://twitter.com/DigThatData).\n", + "\n", + "Sponsored by [StabilityAI](https://twitter.com/stabilityai)\n", + "\n", + "Generate animations with [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) 1.4, using the [KLMC2 discretization of underdamped Langevin dynamics](https://arxiv.org/abs/1807.09382). The notebook is largely inspired by [Ajay Jain](https://twitter.com/ajayj_) and [Ben Poole](https://twitter.com/poolio)'s paper [Journey to the BAOAB-limit](https://www.ajayjain.net/journey)—thank you so much for it!\n", + "\n", + "---\n", + "\n", + "## Modifications Provenance\n", + "\n", + "Original notebook URL - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m8ovBpO2QilE2o4O-p2PONSwqGn4_x2G)\n", + "\n", + "Features and QOL Modifications by [David Marx](https://twitter.com/DigThatData) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmarx/notebooks/blob/main/Stable_Diffusion_KLMC2_Animation.ipynb)\n", + "\n", + "Last updated date (effectively, notebook version): 2022-02-20\n", + "\n", + "* Keyframed prompts and settings\n", + "* Multiprompt conditioning w independent prompt schedules\n", + "* Set seed for deterministic output\n", + "* Mount Google Drive\n", + "* Faster Setup\n", + "* Init image\n", + "* Alt-checkpoint loading consistent w/deforum\n", + "* Set output filename\n", + "* Fancy GPU info\n", + "* Video embed optional\n", + "* ~~Cheaper default runtime~~ torn about this\n", + "* Local setup\n", + "* New VAE option\n", + "* Smooth interpolation for settings curves\n", + "* Settings curves specified via simple DSL\n", + "* Exposed `refinement_steps` parameter\n", + "* Custom output resolution\n", + "* Optional video upscale\n", + "* Optional resume, user can specify 
resumption frame (auto-checkpoints every 10 frames)\n", + "* Optional archival\n", + "* Assorted refactoring\n", + "* Debugging plots and animations\n", + "\n", + "## Local Setup\n", + "\n", + "Download the repo containing this notebook and supplementary setup files.\n", + "\n", + "```\n", + "git clone https://github.com/dmarx/notebooks\n", + "cd notebooks\n", + "```\n", + "\n", + "Strongly recommend setting up and activating a virtual environment first. Here's one option that is built into python, windows users in particular might want to consider using anaconda as an alternative.\n", + "\n", + "```bash\n", + "python3 -m venv _venv\n", + "source _venv/bin/activate\n", + "pip install jupyter\n", + "```\n", + "\n", + "With this venv created, in the future you only need to run `source _venv/bin/activate` to activate it.\n", + "\n", + "You can now start a local jupyter instance from the terminal in which the virtual environment is activated by running the `jupyter` command, or alternatively select the new virtualenv as the python environment in your IDE of choice. When you run the notebook's setup cells, it should detect that local setup needs to be performed and modify its setup procedure appropriately.\n", + "\n", + "A common source of errors is user confusion between the python environment running the notebook and an intended virtual environment into which setup has already been performed. 
To validate that you are using the python environment you think you are, run the command `which python` (this locates the executable associated with the `python` command) both inside the notebook and in a terminal in which your venv is activated: the results should be identical.\n", + "\n", + "## Contact\n", + "\n", + "Report bugs or feature ideas here: https://github.com/dmarx/notebooks/issues" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Ty3IOeXbLzvc", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Check GPU\n", + "#!nvidia-smi\n", + "\n", + "import pandas as pd\n", + "import subprocess\n", + "\n", + "def gpu_info():\n", + " outv = subprocess.run([\n", + " 'nvidia-smi',\n", + " # these lines concatenate into a single query string\n", + " '--query-gpu='\n", + " 'timestamp,'\n", + " 'name,'\n", + " 'utilization.gpu,'\n", + " 'utilization.memory,'\n", + " 'memory.used,'\n", + " 'memory.free,'\n", + " ,\n", + " '--format=csv'\n", + " ],\n", + " stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + "\n", + " header, rec = outv.split('\\n')[:-1]\n", + " return pd.DataFrame({' '.join(k.strip().split('.')).capitalize():v for k,v in zip(header.split(','), rec.split(','))}, index=[0]).T\n", + "\n", + "gpu_info()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "kelHR9VM1-hg", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Install Dependencies\n", + "\n", + "# @markdown Your runtime will automatically restart after running this cell.\n", + "# @markdown You should only need to run this cell when setting up a new runtime. 
After future runtime restarts,\n", + "# @markdown you should be able to skip this cell.\n", + "\n", + "import warnings\n", + "\n", + "probably_using_colab = False\n", + "try:\n", + " import google\n", + " probably_using_colab = True\n", + "except ImportError:\n", + " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", + "\n", + "# @markdown Not recommended for colab users. This notebook is currently configured to only make this\n", + "# @markdown option available for local install.\n", + "use_xformers = False\n", + "\n", + "try:\n", + " import keyframed\n", + "except ImportError:\n", + " if probably_using_colab:\n", + " !pip install ftfy einops braceexpand requests transformers clip open_clip_torch omegaconf pytorch-lightning kornia k-diffusion ninja omegaconf\n", + " !pip install -U git+https://github.com/huggingface/huggingface_hub\n", + " !pip install napm keyframed\n", + " else:\n", + " !pip install -r klmc2/requirements.txt\n", + " if use_xformers:\n", + " !pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n", + "\n", + " exit() # restarts the runtime" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "fJZtXShcPXx5", + "tags": [] + }, + "outputs": [], + "source": [ + "# @markdown # Setup Workspace { display-mode: \"form\" }\n", + "\n", + "###################\n", + "# Setup Workspace #\n", + "###################\n", + "\n", + "import os\n", + "from pathlib import Path\n", + "import warnings\n", + "\n", + "probably_using_colab = False\n", + "try:\n", + " import google\n", + " if Path('/content').exists():\n", + " probably_using_colab = True\n", + " print(\"looks like we're in colab\")\n", + " else:\n", + " print(\"looks like we're not in colab\")\n", + "except ImportError:\n", + " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", + "\n", + "\n", + "mount_gdrive = True # 
@param {type:'boolean'}\n", + "\n", + "# defaults\n", + "outdir = Path('./frames')\n", + "if not os.environ.get('XDG_CACHE_HOME'):\n", + " os.environ['XDG_CACHE_HOME'] = str(Path('~/.cache').expanduser())\n", + "\n", + "if mount_gdrive and probably_using_colab:\n", + " from google.colab import drive\n", + " drive.mount('/content/drive')\n", + " Path('/content/drive/MyDrive/AI/models/.cache/').mkdir(parents=True, exist_ok=True) \n", + " os.environ['XDG_CACHE_HOME']='/content/drive/MyDrive/AI/models/.cache'\n", + " outdir = Path('/content/drive/MyDrive/AI/klmc2/frames/')\n", + "\n", + "# make sure the paths we need exist\n", + "outdir.mkdir(parents=True, exist_ok=True)\n", + "debug_dir = outdir.parent / 'debug_frames'\n", + "debug_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "\n", + "os.environ['NAPM_PATH'] = str( Path(os.environ['XDG_CACHE_HOME']) / 'napm' )\n", + "Path(os.environ['NAPM_PATH']).mkdir(parents=True, exist_ok=True)\n", + "\n", + "\n", + "import napm\n", + "\n", + "url = 'https://github.com/Stability-AI/stablediffusion'\n", + "napm.pseudoinstall_git_repo(url, add_install_dir_to_path=True)\n", + "\n", + "\n", + "##### Moved from model loading cell\n", + "\n", + "if probably_using_colab:\n", + " models_path = \"/content/models\" #@param {type:\"string\"}\n", + "else:\n", + " models_path = os.environ['XDG_CACHE_HOME']\n", + "\n", + "if mount_gdrive and probably_using_colab:\n", + " models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n", + " models_path = models_path_gdrive\n", + "\n", + "if not Path(models_path).exists():\n", + " Path(models_path).mkdir(parents=True, exist_ok=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "y2jXKIf2ZkT8", + "tags": [] + }, + "outputs": [], + "source": [ + "# @markdown # Imports and Definitions { display-mode: \"form\" }\n", + "\n", + "###########\n", + "# imports #\n", + "###########\n", + "\n", + "# importing napm 
puts the stable diffusion repo on the PATH, which is where `ldm` imports from\n", + "import napm\n", + "from ldm.util import instantiate_from_config\n", + "\n", + "from base64 import b64encode\n", + "from collections import defaultdict\n", + "from concurrent import futures\n", + "import math\n", + "from pathlib import Path\n", + "import random\n", + "import re\n", + "import requests\n", + "from requests.exceptions import HTTPError\n", + "import sys\n", + "import time\n", + "from urllib.parse import urlparse\n", + "import warnings\n", + "\n", + "import functorch\n", + "import huggingface_hub\n", + "from IPython.display import display, Video, HTML\n", + "import k_diffusion as K\n", + "from keyframed import Curve, ParameterGroup, SmoothCurve\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np \n", + "from omegaconf import OmegaConf\n", + "import PIL\n", + "from PIL import Image\n", + "import torch\n", + "from torch import nn\n", + "from tqdm.auto import tqdm, trange\n", + "\n", + "from loguru import logger\n", + "import time\n", + "from natsort import natsorted\n", + "\n", + "\n", + "#########################\n", + "# Define useful globals #\n", + "#########################\n", + "\n", + "cpu = torch.device(\"cpu\")\n", + "device = torch.device(\"cuda\")\n", + "\n", + "\n", + "##############################\n", + "# Define necessary functions #\n", + "##############################\n", + " \n", + "import matplotlib.image\n", + "import numpy as np\n", + " \n", + "def get_latest_frame(i=None, latest_frame_fpath=None):\n", + " latest_frame = None\n", + " if latest_frame_fpath is not None:\n", + " latest_frame = latest_frame_fpath\n", + " \n", + " if (latest_frame is None) and (i is None):\n", + " frames = Path('frames').glob(\"*.png\")\n", + " #logger.debug(f\"frames: {len(frames)}\")\n", + " latest_frame = natsort.sort(frames)[-1]\n", + " i = re.findall('out_([0-9]+).png', latest_frame.name)\n", + " else:\n", + " latest_frame = Path('frames') / 
f\"out_{i:05}.png\"\n", + " logger.debug(f'-latest_frame: {latest_frame}')\n", + " #return Image.open(latest_frame)\n", + " img = matplotlib.image.imread(latest_frame)\n", + " return np.flip(img, axis=0) # up/down\n", + "\n", + "def plot_prompts(prompts=None, n=1000, settings=None, **kargs):\n", + " if prompts is not None:\n", + " for prompt in prompts:\n", + " prompt.weight.plot(n=n, **kargs)\n", + "\n", + "def plot_param(param, settings=None, prompts=None, n=1000, **kargs):\n", + " settings.parameters[param].plot(n=n, **kargs)\n", + " \n", + "# move imports up\n", + "import base64\n", + "from io import BytesIO\n", + "from functools import partial\n", + " \n", + "@logger.catch\n", + "def write_debug_frame_at_(\n", + " i=None,\n", + " n=300, \n", + " prompts=None, \n", + " stuff_to_plot=['prompts'], \n", + " latest_frame_fpath=None,\n", + " pil_image=None,\n", + " settings=None,\n", + "):\n", + " plotting_funcs = {\n", + " 'prompts': plot_prompts,\n", + " 'g': partial(plot_param, param='g'),\n", + " 'h': partial(plot_param, param='h'),\n", + " 'sigma': partial(plot_param, param='sigma'),\n", + " 'gamma': partial(plot_param, param='gamma'),\n", + " 'alpha': partial(plot_param, param='alpha'),\n", + " 'tau': partial(plot_param, param='tau'),\n", + " }\n", + " \n", + " # i feel like this line of code justifies the silly variable name\n", + " if not stuff_to_plot:\n", + " return\n", + " \n", + " #stuff_to_plot = []\n", + " \n", + " test_im = pil_image\n", + " if pil_image is None:\n", + " test_im = get_latest_frame(i, latest_frame_fpath)\n", + "\n", + " fig = plt.figure()\n", + " #axsRight = fig.subplots(3, 1, sharex=True)\n", + " #ax = axsRight[0]\n", + " ax_objs = fig.subplots(len(stuff_to_plot), 1, sharex=True)\n", + " \n", + " #width, height = test_im.size\n", + " height, width = test_im.size\n", + " fig.set_size_inches(height/fig.dpi, width/fig.dpi )\n", + " \n", + " buffer = BytesIO()\n", + " for j, category in enumerate(stuff_to_plot):\n", + " ax = ax_objs\n", + 
" if len(stuff_to_plot) > 1:\n", + " ax = ax_objs[j]\n", + " plt.sca(ax)\n", + " plt.tight_layout()\n", + " plt.axis('off')\n", + " \n", + " plotting_funcs[category](prompts=prompts, settings=settings, n=n, zorder=1)\n", + " plt.axvline(x=i)\n", + " \n", + " \n", + "\n", + " #plt.margins(0)\n", + " fig.savefig(buffer, transparent=True) \n", + " plt.close()\n", + "\n", + " buffer.seek(0)\n", + " plot_pil = Image.open(buffer)\n", + " #buffer.close() # throws error here\n", + "\n", + " #debug_im_path = Path('debug_frames') / f\"{category}_out_{i:05}.png\"\n", + " #debug_im_path = Path('debug_frames') / f\"debug_out_{i:05}.png\"\n", + " debug_im_path = debug_dir / f\"debug_out_{i:05}.png\"\n", + " test_im = test_im.convert('RGBA')\n", + " test_im.paste(plot_pil, (0,0), plot_pil)\n", + " test_im.save(debug_im_path)\n", + " #display(test_im) # maybe?\n", + " buffer.close() # I guess?\n", + " \n", + " return test_im, plot_pil\n", + "\n", + "##############################\n", + "\n", + "class Prompt:\n", + " def __init__(\n", + " self,\n", + " text,\n", + " weight_schedule,\n", + " ):\n", + " c = sd_model.get_learned_conditioning([text])\n", + " self.text=text\n", + " self.encoded=c\n", + " self.weight = SmoothCurve(weight_schedule)\n", + "\n", + "\n", + "def handle_chigozienri_curve_format(value_string):\n", + " if value_string.startswith('(') and value_string.endswith(')'):\n", + " value_string = value_string[1:-1]\n", + " return value_string\n", + "\n", + "def parse_curve_string(txt, f=float):\n", + " schedule = {}\n", + " for tokens in txt.split(','):\n", + " k,v = tokens.split(':')\n", + " v = handle_chigozienri_curve_format(v)\n", + " schedule[int(k)] = f(v)\n", + " return schedule\n", + "\n", + "def parse_curvable_string(param, is_int=False):\n", + " if isinstance(param, dict):\n", + " return param\n", + " f = float\n", + " if is_int:\n", + " f = int\n", + " try:\n", + " return f(param)\n", + " except ValueError:\n", + " return parse_curve_string(txt=param, f=f)\n", 
+ "\n", + "##################\n", + "\n", + "def show_video(video_path, video_width=512):\n", + " return display(Video(video_path, width=video_width))\n", + "\n", + "if probably_using_colab:\n", + " def show_video(video_path, video_width=512):\n", + " video_file = open(video_path, \"r+b\").read()\n", + " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", + " return display(HTML(f\"\"\"\"\"\"))\n", + "\n", + "##################\n", + "\n", + "class NormalizingCFGDenoiser(nn.Module):\n", + " def __init__(self, model, g):\n", + " super().__init__()\n", + " self.inner_model = model\n", + " self.g = g\n", + " self.eps_norms = defaultdict(lambda: (0, 0))\n", + "\n", + " def mean_sq(self, x):\n", + " return x.pow(2).flatten(1).mean(1)\n", + "\n", + " @torch.no_grad()\n", + " def update_eps_norm(self, eps, sigma):\n", + " sigma = sigma[0].item()\n", + " eps_norm = self.mean_sq(eps).mean()\n", + " eps_norm_avg, count = self.eps_norms[sigma]\n", + " eps_norm_avg = eps_norm_avg * count / (count + 1) + eps_norm / (count + 1)\n", + " self.eps_norms[sigma] = (eps_norm_avg, count + 1)\n", + " return eps_norm_avg\n", + "\n", + " def forward(self, x, sigma, uncond, cond, g):\n", + " x_in = torch.cat([x] * 2)\n", + " sigma_in = torch.cat([sigma] * 2)\n", + " cond_in = torch.cat([uncond, cond])\n", + "\n", + " denoised = self.inner_model(x_in, sigma_in, cond=cond_in)\n", + " eps = K.sampling.to_d(x_in, sigma_in, denoised)\n", + " eps_uc, eps_c = eps.chunk(2)\n", + " eps_norm = self.update_eps_norm(eps, sigma).sqrt()\n", + " c = eps_c - eps_uc\n", + " cond_scale = g * eps_norm / self.mean_sq(c).sqrt()\n", + " eps_final = eps_uc + c * K.utils.append_dims(cond_scale, x.ndim)\n", + " return x - eps_final * K.utils.append_dims(sigma, eps.ndim)\n", + "\n", + "#########################\n", + "\n", + "def write_klmc2_state(**state):\n", + " st = time.time()\n", + " obj = {}\n", + " for k,v in state.items():\n", + " try:\n", + " v = v.clone().detach().cpu()\n", + " 
except AttributeError:\n", + " # if it doesn't have a detach method, we don't need to worry about any preprocessing\n", + " pass\n", + " obj[k] = v\n", + "\n", + " checkpoint_fpath = Path(outdir) / f\"klmc2_state_{state.get('i',0):05}.ckpt\"\n", + " with open(checkpoint_fpath, 'wb') as f:\n", + " torch.save(obj, f=f)\n", + " et = time.time()\n", + " #logger.debug(f\"checkpointing: {et-st}\")\n", + " # to do: move to callback? thread executor, anyway\n", + "\n", + "def read_klmc2_state(root=outdir, latest_frame=-1):\n", + " state = {}\n", + " checkpoints = [str(p) for p in Path(root).glob(\"*.ckpt\")]\n", + " if not checkpoints:\n", + " return None\n", + " checkpoints = natsorted(checkpoints)\n", + " if latest_frame < 0:\n", + " ckpt_fpath = checkpoints[-1]\n", + " else:\n", + " for fname in checkpoints:\n", + " frame_id = re.findall(r'([0-9]+).ckpt', fname)[0]\n", + " if int(frame_id) <= latest_frame:\n", + " ckpt_fpath = fname\n", + " else:\n", + " break\n", + " logger.debug(ckpt_fpath)\n", + " with open(ckpt_fpath,'rb') as f:\n", + " state = torch.load(f=f,map_location='cuda')\n", + " return state\n", + "\n", + "def load_init_image(init_image, height, width):\n", + " if not Path(init_image).exists():\n", + " raise FileNotFoundError(f\"Unable to locate init image from path: {init_image}\")\n", + " \n", + " \n", + " from PIL import Image\n", + " import numpy as np\n", + "\n", + " init_im_pil = Image.open(init_image)\n", + "\n", + " #x_pil = init_im_pil.resize([512,512])\n", + " x_pil = init_im_pil.resize([height,width])\n", + " x_np = np.array(x_pil.convert('RGB')).astype(np.float16) / 255.0\n", + " x = x_np[None].transpose(0, 3, 1, 2)\n", + " x = 2.*x - 1.\n", + " x = torch.from_numpy(x).to('cuda')\n", + " return x\n", + "\n", + "def save_image_fn(image, name, i, n, prompts=None, settings=None, stuff_to_plot=['prompts']):\n", + " pil_image = K.utils.to_pil_image(image)\n", + " if i % 10 == 0 or i == n - 1:\n", + " print(f'\\nIteration {i}/{n}:')\n", + " 
display(pil_image)\n", + " if i == n - 1:\n", + " print('\\nDone!')\n", + " pil_image.save(name)\n", + " if stuff_to_plot:\n", + " #logger.debug(stuff_to_plot)\n", + " #write_debug_frame_at_(i, prompts=prompts)\n", + " debug_frame, debug_plot = write_debug_frame_at_(i=i,n=n, prompts=prompts, settings=settings, stuff_to_plot=stuff_to_plot, pil_image=pil_image)\n", + " if i % 10 == 0 or i == n - 1:\n", + " #display(debug_frame)\n", + " display(debug_plot)\n", + "\n", + "###############################\n", + "\n", + "@torch.no_grad()\n", + "def sample_mcmc_klmc2(\n", + " sd_model, \n", + " init_image,\n", + " height:int,\n", + " width:int,\n", + " n:int, \n", + " hvp_method:str='reverse', \n", + " prompts:list=None,\n", + " settings:ParameterGroup=None,\n", + " resume:bool = False,\n", + " resume_from:int=-1,\n", + " img_init_steps:int=None,\n", + " stuff_to_plot:list=None,\n", + " checkpoint_every:int=10,\n", + "):\n", + "\n", + " if stuff_to_plot is None:\n", + " stuff_to_plot = ['prompts','h']\n", + " \n", + " torch.cuda.empty_cache()\n", + "\n", + " wrappers = {'eps': K.external.CompVisDenoiser, 'v': K.external.CompVisVDenoiser}\n", + " g = settings[0]['g']\n", + "\n", + " model_wrap = wrappers[sd_model.parameterization](sd_model)\n", + " model_wrap_cfg = NormalizingCFGDenoiser(model_wrap, g)\n", + " sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()\n", + " model = model_wrap_cfg\n", + "\n", + " uc = sd_model.get_learned_conditioning([''])\n", + " extra_args = {\n", + " 'uncond': uc, \n", + " #'cond': prompts[0].encoded,\n", + " 'g': settings[0]['g']\n", + " }\n", + "\n", + " sigma = settings[0]['sigma']\n", + "\n", + " with torch.cuda.amp.autocast(), futures.ThreadPoolExecutor() as ex:\n", + " def callback(info):\n", + " i = info['i']\n", + " rgb = sd_model.decode_first_stage(info['denoised'] )\n", + " ex.submit(save_image_fn, image=rgb, name=(outdir / f'out_{i:05}.png'), i=i, n=n, prompts=prompts, settings=settings, 
stuff_to_plot=stuff_to_plot)\n", + "\n", + " # Initialize the chain\n", + " print('Initializing the chain...')\n", + "\n", + " # to do: if resuming, generating this init image is unnecessary\n", + " x = None\n", + " if init_image:\n", + " print(\"loading init image\")\n", + " x = load_init_image(init_image, height, width)\n", + " # convert RGB to latent\n", + " x = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x))\n", + " print(\"init image loaded.\")\n", + "\n", + " print('Actually doing the sampling...')\n", + "\n", + " i_resume=0\n", + " v = None\n", + " if resume:\n", + " state = read_klmc2_state(latest_frame=resume_from)\n", + " if state:\n", + " x, v, i_resume = state['x'], state['v'], state['i']\n", + " # to do: resumption of settings\n", + " settings_i = state['settings_i']\n", + " settings[i]['h'] = settings_i['h']\n", + " settings[i]['gamma'] = settings_i['gamma']\n", + " settings[i]['alpha'] = settings_i['alpha']\n", + " settings[i]['tau'] = settings_i['tau']\n", + " settings[i]['g'] = settings_i['g']\n", + " settings[i]['sigma'] = settings_i['sigma']\n", + " settings[i]['steps'] = settings_i['steps']\n", + " \n", + " # to do: use multicond for init image\n", + " # we want this test after resumption if resuming\n", + " if x is None:\n", + " print(\"No init image provided, generating a random init image\")\n", + " extra_args['cond'] = prompts[0].encoded\n", + " h=height//8\n", + " w=width//8\n", + " x = torch.randn([1, 4, h, w], device=device) * sigma_max\n", + " sigmas_pre = K.sampling.get_sigmas_karras(img_init_steps, sigma, sigma_max, device=x.device)[:-1]\n", + " x = K.sampling.sample_dpmpp_sde(model_wrap_cfg, x, sigmas_pre, extra_args=extra_args)\n", + "\n", + " # if not resuming, randomly initialize momentum\n", + " # this needs to be *after* generating X if we're going to...\n", + " if v is None:\n", + " v = torch.randn_like(x) * sigma\n", + "\n", + " # main sampling loop\n", + " for i in trange(n):\n", + " # fast-forward loop to 
resumption index\n", + " if resume and i < i_resume:\n", + " continue\n", + " # if resume and (i == i_resume):\n", + " # # should these values be written into settings[i]?\n", + " # h = settings_i['h']\n", + " # gamma = settings_i['gamma']\n", + " # alpha = settings_i['alpha']\n", + " # tau = settings_i['tau']\n", + " # g = settings_i['g']\n", + " # sigma = settings_i['sigma']\n", + " # steps = settings_i['steps']\n", + " # else:\n", + " h = settings[i]['h']\n", + " gamma = settings[i]['gamma']\n", + " alpha = settings[i]['alpha']\n", + " tau = settings[i]['tau']\n", + " g = settings[i]['g']\n", + " sigma = settings[i]['sigma']\n", + " steps = settings[i]['steps']\n", + "\n", + " h = torch.tensor(h, device=x.device)\n", + " gamma = torch.tensor(gamma, device=x.device)\n", + " alpha = torch.tensor(alpha, device=x.device)\n", + " tau = torch.tensor(tau, device=x.device)\n", + " sigma = torch.tensor(sigma, device=x.device)\n", + " steps = int(steps)\n", + " \n", + " sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma.item(), device=x.device)[:-1]\n", + "\n", + " x, v, grad = klmc2_step(\n", + " model,\n", + " prompts,\n", + " x,\n", + " v,\n", + " h,\n", + " gamma,\n", + " alpha,\n", + " tau,\n", + " g,\n", + " sigma,\n", + " sigmas,\n", + " steps,\n", + " hvp_method,\n", + " i,\n", + " callback,\n", + " extra_args,\n", + " )\n", + "\n", + " save_checkpoint = (i % checkpoint_every) == 0\n", + " if save_checkpoint:\n", + " settings_i = settings[i]\n", + " ex.submit(write_klmc2_state, v=v, x=x, i=i, settings_i=settings_i)\n", + " logger.debug(settings[i])\n", + "\n", + "\n", + "def hvp_fn_forward_functorch(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " jvp_fn = lambda v: functorch.jvp(grad_fn, (x, sigma), (v, torch.zeros_like(sigma)))\n", + " grad, jvp_out = functorch.vmap(jvp_fn)(v)\n", + " 
return grad[0], jvp_out\n", + "\n", + "def hvp_fn_reverse(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " vjps = []\n", + " with torch.enable_grad():\n", + " x_ = x.clone().requires_grad_()\n", + " grad = grad_fn(x_, sigma)\n", + " for k, item in enumerate(v):\n", + " vjp_out = torch.autograd.grad(grad, x_, item, retain_graph=k < len(v) - 1)[0]\n", + " vjps.append(vjp_out)\n", + " return grad, torch.stack(vjps)\n", + "\n", + "def hvp_fn_zero(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " return grad_fn(x, sigma), torch.zeros_like(v)\n", + "\n", + "def hvp_fn_fake(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " return grad_fn(x, sigma), (1 + alpha) * v\n", + "\n", + "\n", + "def multicond_hvp(model, x, sigma, v, alpha, extra_args, prompts, hvp_fn, i):\n", + "\n", + " # loop over prompts and aggregate gradients for multicond\n", + " grad = torch.zeros_like(x)\n", + " h2_v = torch.zeros_like(x)\n", + " h2_noise_v2 = torch.zeros_like(x)\n", + " h2_noise_x2 = torch.zeros_like(x)\n", + " wt_norm = 0\n", + " for prompt in prompts:\n", + " wt = prompt.weight[i]\n", + " if wt == 0:\n", + " continue\n", + " wt_norm += wt\n", + " wt = torch.tensor(wt, device=x.device)\n", + " extra_args['cond'] = prompt.encoded\n", + "\n", + " # Estimate gradient and hessian\n", + " grad_, (h2_v_, h2_noise_v2_, h2_noise_x2_) = hvp_fn(\n", + " model=model,\n", + " x=x, \n", + " sigma=sigma, \n", + " v=v,\n", + " alpha=alpha,\n", + " extra_args=extra_args,\n", + " )\n", + "\n", + " grad = grad + 
grad_ * wt \n", + " h2_v = h2_v + h2_v_ * wt\n", + " h2_noise_v2 = h2_noise_v2 + h2_noise_v2_ * wt\n", + " h2_noise_x2 = h2_noise_x2 + h2_noise_x2_ * wt\n", + "\n", + " # Normalize gradient to magnitude it'd have if just single prompt w/ wt=1.\n", + " # simplifies multicond w/o deep frying image or adding hyperparams\n", + " grad = grad / wt_norm \n", + " h2_v = h2_v / wt_norm\n", + " h2_noise_v2 = h2_noise_v2 / wt_norm\n", + " h2_noise_x2 = h2_noise_x2 / wt_norm\n", + "\n", + " return grad, h2_v, h2_noise_v2, h2_noise_v2, h2_noise_x2\n", + "\n", + "\n", + "\n", + "\n", + "def klmc2_step(\n", + " model,\n", + " prompts,\n", + " x,\n", + " v,\n", + " h,\n", + " gamma,\n", + " alpha,\n", + " tau,\n", + " g,\n", + " sigma,\n", + " sigmas,\n", + " steps,\n", + " hvp_method,\n", + " i,\n", + " callback,\n", + " extra_args,\n", + " ):\n", + "\n", + " #s_in = x.new_ones([x.shape[0]])\n", + "\n", + " # Model helper functions\n", + "\n", + " hvp_fns = {'forward-functorch': hvp_fn_forward_functorch,\n", + " 'reverse': hvp_fn_reverse,\n", + " 'zero': hvp_fn_zero,\n", + " 'fake': hvp_fn_fake}\n", + "\n", + " hvp_fn = hvp_fns[hvp_method]\n", + "\n", + " # KLMC2 helper functions\n", + " def psi_0(gamma, t):\n", + " return torch.exp(-gamma * t)\n", + "\n", + " def psi_1(gamma, t):\n", + " return -torch.expm1(-gamma * t) / gamma\n", + "\n", + " def psi_2(gamma, t):\n", + " return (torch.expm1(-gamma * t) + gamma * t) / gamma ** 2\n", + "\n", + " def phi_2(gamma, t_):\n", + " t = t_.double()\n", + " out = (torch.exp(-gamma * t) * (torch.expm1(gamma * t) - gamma * t)) / gamma ** 2\n", + " return out.to(t_)\n", + "\n", + " def phi_3(gamma, t_):\n", + " t = t_.double()\n", + " out = (torch.exp(-gamma * t) * (2 + gamma * t + torch.exp(gamma * t) * (gamma * t - 2))) / gamma ** 3\n", + " return out.to(t_)\n", + "\n", + "\n", + " # Compute model outputs and sample noise\n", + " x_trapz = torch.linspace(0, h, 1001, device=x.device)\n", + " y_trapz = [fun(gamma, x_trapz) for fun in (psi_0, 
psi_1, phi_2, phi_3)]\n", + " noise_cov = torch.tensor([[torch.trapz(y_trapz[i] * y_trapz[j], x=x_trapz) for j in range(4)] for i in range(4)], device=x.device)\n", + " noise_v, noise_x, noise_v2, noise_x2 = torch.distributions.MultivariateNormal(x.new_zeros([4]), noise_cov).sample(x.shape).unbind(-1)\n", + "\n", + " extra_args['g']=g\n", + "\n", + " # compute derivatives, multicond wrapper loops over prompts and averages derivatives\n", + " grad, h2_v, h2_noise_v2, h2_noise_v2, h2_noise_x2 = multicond_hvp(\n", + " model=model, \n", + " x=x, \n", + " sigma=sigma, \n", + " v=torch.stack([v, noise_v2, noise_x2]), # need a \"dummy\" v for init image generation\n", + " alpha=alpha, \n", + " extra_args=extra_args, \n", + " prompts=prompts, \n", + " hvp_fn=hvp_fn,\n", + " i=i,\n", + " )\n", + "\n", + " # DPM-Solver++(2M) refinement steps\n", + " x_refine = x\n", + " use_dpm = True\n", + " old_denoised = None\n", + " for j in range(len(sigmas) - 1):\n", + " if j == 0:\n", + " denoised = x_refine - grad\n", + " else:\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x_refine, sigmas[j] * s_in, **extra_args)\n", + " dt_ode = sigmas[j + 1] - sigmas[j]\n", + " if not use_dpm or old_denoised is None or sigmas[j + 1] == 0:\n", + " eps = K.sampling.to_d(x_refine, sigmas[j], denoised)\n", + " x_refine = x_refine + eps * dt_ode\n", + " else:\n", + " h_ode = sigmas[j].log() - sigmas[j + 1].log()\n", + " h_last = sigmas[j - 1].log() - sigmas[j].log()\n", + " fac = h_ode / (2 * h_last)\n", + " denoised_d = (1 + fac) * denoised - fac * old_denoised\n", + " eps = K.sampling.to_d(x_refine, sigmas[j], denoised_d)\n", + " x_refine = x_refine + eps * dt_ode\n", + " old_denoised = denoised\n", + " if callback is not None:\n", + " callback({'i': i, 'denoised': x_refine})\n", + "\n", + " # Update the chain\n", + " noise_std = (2 * gamma * tau * sigma ** 2).sqrt()\n", + " v_next = 0 + psi_0(gamma, h) * v - psi_1(gamma, h) * grad - phi_2(gamma, h) * h2_v + noise_std * (noise_v - 
h2_noise_v2)\n", + " x_next = x + psi_1(gamma, h) * v - psi_2(gamma, h) * grad - phi_3(gamma, h) * h2_v + noise_std * (noise_x - h2_noise_x2)\n", + " v, x = v_next, x_next\n", + "\n", + " return x, v, grad " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "yt3d1hww17ST", + "tags": [] + }, + "outputs": [], + "source": [ + "#@markdown **Select and Load Model**\n", + "\n", + "## TO DO:\n", + "## - if local, try to load model from ~/.cache/huggingface/diffusers\n", + "\n", + "# modified from:\n", + "# https://github.com/deforum/stable-diffusion/blob/main/Deforum_Stable_Diffusion.ipynb\n", + "\n", + "import napm\n", + "from ldm.util import instantiate_from_config\n", + "\n", + "\n", + "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", + "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\", \"robo-diffusion-v1.ckpt\",\"waifu-diffusion-v1-3.ckpt\"]\n", + "if model_checkpoint == \"waifu-diffusion-v1-3.ckpt\":\n", + " model_checkpoint = \"model-epoch05-float16.ckpt\"\n", + "custom_config_path = \"\" #@param {type:\"string\"}\n", + "custom_checkpoint_path = \"\" #@param {type:\"string\"}\n", + "\n", + "half_precision = True # check\n", + "check_sha256 = False #@param {type:\"boolean\"}\n", + "\n", + "model_map = {\n", + " \"sd-v1-4-full-ema.ckpt\": {\n", + " 'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-4-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-4.ckpt\": {\n", + " 'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',\n", 
+ " 'requires_login': True,\n", + " },\n", + " \"sd-v1-3-full-ema.ckpt\": {\n", + " 'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-3.ckpt\": {\n", + " 'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-2-full-ema.ckpt\": {\n", + " 'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-2.ckpt\": {\n", + " 'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-1-full-ema.ckpt\": {\n", + " 'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',\n", + " 'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-1.ckpt\": {\n", + " 'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"robo-diffusion-v1.ckpt\": {\n", + " 'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',\n", + " 'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',\n", + " 'requires_login': False,\n", + " },\n", + " \"model-epoch05-float16.ckpt\": {\n", + " 
'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece',\n", + " 'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt',\n", + " 'requires_login': False,\n", + " },\n", + "}\n", + "\n", + "# config path\n", + "ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n", + "if os.path.exists(ckpt_config_path):\n", + " print(f\"{ckpt_config_path} exists\")\n", + "else:\n", + " #ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n", + " ckpt_config_path = \"./v1-inference.yaml\"\n", + " if not Path(ckpt_config_path).exists():\n", + " !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n", + " \n", + "print(f\"Using config: {ckpt_config_path}\")\n", + "\n", + "# checkpoint path or download\n", + "ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n", + "ckpt_valid = True\n", + "if os.path.exists(ckpt_path):\n", + " print(f\"{ckpt_path} exists\")\n", + "elif 'url' in model_map[model_checkpoint]:\n", + " url = model_map[model_checkpoint]['url']\n", + "\n", + " # CLI dialogue to authenticate download\n", + " if model_map[model_checkpoint]['requires_login']:\n", + " print(\"This model requires an authentication token\")\n", + " print(\"Please ensure you have accepted its terms of service before continuing.\")\n", + "\n", + " username = input(\"What is your huggingface username?:\")\n", + " token = input(\"What is your huggingface token?:\")\n", + "\n", + " _, path = url.split(\"https://\")\n", + "\n", + " url = f\"https://{username}:{token}@{path}\"\n", + "\n", + " # contact server for model\n", + " print(f\"Attempting to download {model_checkpoint}...this may take a while\")\n", + " ckpt_request = requests.get(url)\n", + " request_status = ckpt_request.status_code\n", + "\n", + " # inform 
user of errors\n", + " if request_status == 403:\n", + " raise ConnectionRefusedError(\"You have not accepted the license for this model.\")\n", + " elif request_status == 404:\n", + " raise ConnectionError(\"Could not make contact with server\")\n", + " elif request_status != 200:\n", + " raise ConnectionError(f\"Some other error has ocurred - response code: {request_status}\")\n", + "\n", + " # write to model path\n", + " with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file:\n", + " model_file.write(ckpt_request.content)\n", + "else:\n", + " print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n", + " ckpt_valid = False\n", + "\n", + "if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n", + " import hashlib\n", + " print(\"\\n...checking sha256\")\n", + " with open(ckpt_path, \"rb\") as f:\n", + " bytes = f.read() \n", + " hash = hashlib.sha256(bytes).hexdigest()\n", + " del bytes\n", + " if model_map[model_checkpoint][\"sha256\"] == hash:\n", + " print(\"hash is correct\\n\")\n", + " else:\n", + " print(\"hash in not correct\\n\")\n", + " ckpt_valid = False\n", + "\n", + "if ckpt_valid:\n", + " print(f\"Using ckpt: {ckpt_path}\")\n", + "\n", + "def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n", + " map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n", + " print(f\"Loading model from {ckpt}\")\n", + " pl_sd = torch.load(ckpt, map_location=map_location)\n", + " if \"global_step\" in pl_sd:\n", + " print(f\"Global Step: {pl_sd['global_step']}\")\n", + " sd = pl_sd[\"state_dict\"]\n", + " model = instantiate_from_config(config.model)\n", + " m, u = model.load_state_dict(sd, strict=False)\n", + " if len(m) > 0 and verbose:\n", + " print(\"missing keys:\")\n", + " print(m)\n", + " if len(u) > 0 and verbose:\n", + " print(\"unexpected keys:\")\n", + " print(u)\n", + "\n", + " if half_precision:\n", + " model = 
model.half().to(device)\n", + " else:\n", + " model = model.to(device)\n", + " model.eval()\n", + " return model\n", + "\n", + "if ckpt_valid:\n", + " local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n", + " model = load_model_from_config(local_config, f\"{ckpt_path}\", half_precision=half_precision)\n", + " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + " model = model.to(device)\n", + "\n", + " # Disable checkpointing as it is not compatible with the method\n", + " for module in model.modules():\n", + " if hasattr(module, 'checkpoint'):\n", + " module.checkpoint = False\n", + " if hasattr(module, 'use_checkpoint'):\n", + " module.use_checkpoint = False\n", + "\n", + " sd_model=model\n", + "\n", + "####################################################################\n", + "\n", + "use_new_vae = True #@param {type:\"boolean\"}\n", + "\n", + "if use_new_vae:\n", + "\n", + " # from kat's notebook again\n", + "\n", + " def download_from_huggingface(repo, filename):\n", + " while True:\n", + " try:\n", + " return huggingface_hub.hf_hub_download(repo, filename)\n", + " except HTTPError as e:\n", + " if e.response.status_code == 401:\n", + " # Need to log into huggingface api\n", + " huggingface_hub.interpreter_login()\n", + " continue\n", + " elif e.response.status_code == 403:\n", + " # Need to do the click through license thing\n", + " print(f'Go here and agree to the click through license on your account: https://huggingface.co/{repo}')\n", + " input('Hit enter when ready:')\n", + " continue\n", + " else:\n", + " raise e\n", + "\n", + " vae_840k_model_path = download_from_huggingface(\"stabilityai/sd-vae-ft-mse-original\", \"vae-ft-mse-840000-ema-pruned.ckpt\")\n", + "\n", + " def load_model_from_config_kc(config, ckpt):\n", + " print(f\"Loading model from {ckpt}\")\n", + " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n", + " sd = pl_sd[\"state_dict\"]\n", + " config = OmegaConf.load(config)\n", + "\n", + " 
try:\n", + " config['model']['params']['lossconfig']['target'] = \"torch.nn.Identity\"\n", + " print('Patched VAE config.')\n", + " except KeyError:\n", + " pass\n", + "\n", + " model = instantiate_from_config(config.model)\n", + " m, u = model.load_state_dict(sd, strict=False)\n", + " model = model.to(cpu).eval().requires_grad_(False)\n", + " return model\n", + "\n", + " vaemodel_yaml_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/models/first_stage_models/kl-f8/config.yaml\"\n", + " vaemodel_yaml_fname = 'config_vae_kl-f8.yaml'\n", + " vaemodel_yaml_fname_git = \"latent-diffusion/models/first_stage_models/kl-f8/config.yaml\"\n", + " if Path(vaemodel_yaml_fname_git).exists():\n", + " vae_model = load_model_from_config_kc(vaemodel_yaml_fname_git, vae_840k_model_path).half().to(device)\n", + " else:\n", + " if not Path(vaemodel_yaml_fname).exists():\n", + " !wget {vaemodel_yaml_url} -O {vaemodel_yaml_fname}\n", + " vae_model = load_model_from_config_kc(vaemodel_yaml_fname, vae_840k_model_path).half().to(device)\n", + "\n", + " del sd_model.first_stage_model\n", + " sd_model.first_stage_model = vae_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ZljSF1ePnBl4", + "tags": [] + }, + "outputs": [], + "source": [ + "# @title Settings\n", + "\n", + "# @markdown The number of frames to sample:\n", + "n = 300 # @param {type:\"integer\"}\n", + "\n", + "# @markdown height and width must be multiples of 8 (e.g. 256, 512, 768, 1024)\n", + "height = 512 # @param {type:\"integer\"}\n", + "\n", + "width = 512 # @param {type:\"integer\"}\n", + "\n", + "\n", + "# @markdown If seed is negative, a random seed will be used\n", + "seed = -1 # @param {type:\"number\"}\n", + "\n", + "init_image = \"\" # @param {type:'string'}\n", + "\n", + "# @markdown ---\n", + "\n", + "# @markdown Settings below this line can be parameterized using keyframe syntax: `\"time:weight, time:weight, ...\". 
\n", + "# @markdown Over spans where values of weights change, intermediate values will be interpolated using an \"s\" shaped curve.\n", + "# @markdown If a value for keyframe 0 is not specified, it is presumed to be `0:0`.\n", + "\n", + "# @markdown The strength of the conditioning on the prompt:\n", + "g=\"0:0.1\" # @param {type:\"string\"}\n", + "\n", + "# @markdown The noise level to sample at\n", + "# @markdown Ramp up from a tiny sigma if using init image, e.g. `0:0.25, 100:2, ...`\n", + "# @markdown NB: Turning sigma *up* mid generation seems to work fine, but turning sigma *down* mid generation tends to \"deep fry\" the outputs\n", + "sigma = \"1.25\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Step size (range 0 to 1):\n", + "h = \"0:0.1, 30:0.1, 50:0.3, 70:0.1, 120:0.1, 140:.3, 160:.1, 210:.1, 230:.3, 250:.1\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Friction (2 is critically damped, lower -> smoother animation):\n", + "gamma = \"1.1\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Quadratic penalty (\"weight decay\") strength:\n", + "alpha = \"0.005\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Temperature (adjustment to the amount of noise added per step):\n", + "tau = \"1.0\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Denoising refinement steps:\n", + "refinement_steps = \"6\" # @param {type:\"string\"}\n", + "\n", + "# @markdown If an init image is not provided, this is how many steps will be used when generating an initial state:\n", + "img_init_steps = 15 # @param {type:\"number\"}\n", + "\n", + "# @markdown The HVP method:\n", + "# @markdown
`forward-functorch` and `reverse` provide real second derivatives. Compatibility, speed, and memory usage vary by model and xformers configuration.\n", + "# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.\n", + "hvp_method = 'fake' # @param [\"forward-functorch\", \"reverse\", \"fake\", \"zero\"]\n", + "\n", + "checkpoint_every = 10 # @param {type:\"number\"}\n", + "\n", + "###########################\n", + "\n", + "assert (height % 8) == 0\n", + "assert (width % 8) == 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1pLTsdGBPXx6", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Prompts\n", + "\n", + "# [ \n", + "# [\"first prompt will be used to initialize the image\", {time:weight, time:weight...}], \n", + "# [\"more prompts if you want\", {...}], \n", + "# ...]\n", + "\n", + "# if a weight for time=0 isn't specified, the weight is assumed to be zero.\n", + "# if you want to \"fade in\" any prompts, it's best to have them start with a small but non-zero value, e.g. 
0.001\n", + "\n", + "prompt_params = [\n", + " # # FIRST PROMPT INITIALIZES IMAGE\n", + " #[\"sweetest puppy, golden retriever\", {0:.5, 30:0.5, 100:0.001}],\n", + " #[\"sweet old dog, golden retriever\", {0:0.001, 30:0.001, 100:0.5}],\n", + " #[\"happiest pupper, cutest dog evar, golden retriever, incredibly photogenic dog\", {0:1}],\n", + "\n", + " # # the 'flowers prompts' below go with a particular 'h' setting in the next cell\n", + " [\"incredibly beautiful orchids, a bouquet of orchids\", {0:1, 35:1, 50:0}],\n", + " [\"incredibly beautiful roses, a bouquet of roses\", {0:0.001, 35:0.001, 50:1, 120:1, 140:0}],\n", + " [\"incredibly beautiful carnations, a bouquet of carnations\", {0:0.001, 120:0.001, 140:1, 220:1, 240:0}],\n", + " [\"incredibly beautiful carnations, a bouquet of sunflowers\", {0:0.001, 220:0.001, 240:1}],\n", + " \n", + " # negative prompts\n", + " [\"watermark text\", {0:-0.1} ],\n", + " [\"jpeg artifacts\", {0:-0.1} ],\n", + " [\"artist's signature\", {0:-0.1} ],\n", + " [\"istockphoto, gettyimages, watermarked image\", {0:-0.1} ],\n", + "]\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "hZ0lh-WkdB19", + "tags": [] + }, + "outputs": [], + "source": [ + "# @title Build prompt and settings objects\n", + "\n", + "# @markdown some advanced features demonstrated in commented-out code in this cell\n", + "\n", + "plot_prompt_weight_curves = True # @param {type: 'boolean'}\n", + "plot_settings_weight_curves = True # @param {type: 'boolean'}\n", + "\n", + "#################\n", + "\n", + "# Build Prompt objects\n", + "\n", + "prompts = [\n", + " Prompt(text, weight_schedule) \n", + " for (text, weight_schedule) in prompt_params\n", + "]\n", + "\n", + "# uncomment to loop the prompts\n", + "#for p in prompts:\n", + "# if len(p.weight.keyframes) > 1: # ignore negative prompts\n", + "# p.weight.loop=True \n", + "\n", + "# uncomment to loop prompts in \"bounce\" mode\n", + 
"#for p in prompts:\n", + "# if len(p.weight.keyframes) > 1:\n", + "# p.weight.bounce=True \n", + "\n", + "#################\n", + "\n", + "# Build Settings object\n", + "\n", + "g = parse_curvable_string(g)\n", + "sigma = parse_curvable_string(sigma)\n", + "h = parse_curvable_string(h)\n", + "gamma = parse_curvable_string(gamma)\n", + "alpha = parse_curvable_string(alpha)\n", + "tau = parse_curvable_string(tau)\n", + "steps = parse_curvable_string(refinement_steps)\n", + "\n", + "\n", + "curved_settings = ParameterGroup({\n", + " 'g':SmoothCurve(g),\n", + " 'sigma':SmoothCurve(sigma),\n", + " #'h':SmoothCurve(h),\n", + " \n", + " # more concise notation for flowers demo:\n", + " 'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", + " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3, 70:0.1, 90:0.1}, loop=True),\n", + "\n", + " 'gamma':SmoothCurve(gamma),\n", + " 'alpha':SmoothCurve(alpha),\n", + " 'tau':SmoothCurve(tau),\n", + " 'steps':SmoothCurve(steps),\n", + "})\n", + "\n", + "\n", + "if plot_prompt_weight_curves:\n", + " for prompt in prompts:\n", + " prompt.weight.plot(n=n)\n", + " plt.title(\"prompt weight schedules\")\n", + " plt.show()\n", + "\n", + "\n", + "if plot_settings_weight_curves:\n", + " for name, curve in curved_settings.parameters.items():\n", + " curve.plot(n=n)\n", + " plt.title(name)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [], + "id": "tthag9k67Uey" + }, + "outputs": [], + "source": [ + "# @markdown running this cell saves the current settings to disk\n", + "\n", + "import keyframed.serialization\n", + "txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", + "\n", + "#print(txt)\n", + "\n", + "# sigma: 1.25\n", + "#\n", + "# becomes:\n", + "#\n", + "# sigma:\n", + "# curve:\n", + "# - - 0\n", + "# - 1.25\n", + "# - eased_lerp\n", + "#\n", + "# :\n", + "# curve:\n", + "# - - \n", + "# - \n", + "# - \n", + "# - \n", + "# - - \n", + "# - \n", + "# - - \n", + 
"# - \n", + "\n", + "with open(outdir / 'settings.yaml', 'w') as f:\n", + " f.write(txt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [], + "cellView": "form", + "id": "srbY3kDa7Uey" + }, + "outputs": [], + "source": [ + "# load settings from disk\n", + "\n", + "# @markdown override current settings using the contents of `frames/settings.yaml`\n", + "\n", + "import keyframed.serialization\n", + "\n", + "load_settings_from_disk = True # @param {type:'boolean'}\n", + "\n", + "if load_settings_from_disk:\n", + " with open(outdir / 'settings.yaml', 'r') as f:\n", + " curved_settings = keyframed.serialization.from_yaml(f.read())\n", + "\n", + "curved_settings.to_dict(simplify=True)['parameters']\n", + "#curved_settings.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "i-_u1Q0wRqMb", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Generate Animation Frames\n", + "\n", + "resume = True # @param {type:'boolean'}\n", + "archive_old_work = False # @param {type:'boolean'}\n", + "\n", + "# -1 = most recent frame\n", + "resume_from = -1 # @param {type:'number'}\n", + "\n", + "# @markdown optional debugging plots\n", + "plot_prompt_weights = True # @param {type:'boolean'}\n", + "plot_h = False # @param {type:'boolean'}\n", + "plot_g = False # @param {type:'boolean'}\n", + "plot_sigma = False # @param {type:'boolean'}\n", + "plot_gamma = False # @param {type:'boolean'}\n", + "plot_alpha = False # @param {type:'boolean'}\n", + "plot_tau = False # @param {type:'boolean'}\n", + "\n", + "################\n", + "\n", + "_seed = seed\n", + "if seed < 0: \n", + " _seed = random.randrange(0, 4294967295)\n", + "print(f\"using seed: {_seed}\")\n", + "torch.manual_seed(_seed)\n", + "\n", + "stuff_to_plot = []\n", + "if plot_prompt_weights:\n", + " stuff_to_plot.append('prompts')\n", + "if plot_h:\n", + " stuff_to_plot.append('h')\n", + "if plot_g:\n", + " 
stuff_to_plot.append('g')\n", + "if plot_sigma:\n", + " stuff_to_plot.append('sigma')\n", + "if plot_gamma:\n", + " stuff_to_plot.append('gamma')\n", + "if plot_alpha:\n", + " stuff_to_plot.append('alpha')\n", + "if plot_tau:\n", + " stuff_to_plot.append('tau')\n", + "\n", + "if not resume:\n", + " if archive_old_work:\n", + " archive_dir = outdir.parent / 'archive' / str(int(time.time()))\n", + " archive_dir.mkdir(parents=True, exist_ok=True)\n", + " print(f\"Archiving contents of /frames, moving to: {archive_dir}\")\n", + " else:\n", + " print(\"Old contents of /frames being deleted. This can be prevented in the future by setting either 'resume' or 'archive_old_work' to True.\")\n", + " for p in outdir.glob(f'*'):\n", + " if archive_old_work:\n", + " target = archive_dir / p.name\n", + " p.rename(target)\n", + " else:\n", + " p.unlink()\n", + " for p in Path('debug_frames').glob(f'*'):\n", + " p.unlink()\n", + "\n", + "sample_mcmc_klmc2(\n", + " sd_model=sd_model,\n", + " init_image=init_image,\n", + " height=height,\n", + " width=width,\n", + " n=n,\n", + " hvp_method=hvp_method,\n", + " prompts=prompts,\n", + " settings=curved_settings,\n", + " resume=resume,\n", + " resume_from=resume_from,\n", + " img_init_steps=img_init_steps,\n", + " stuff_to_plot=stuff_to_plot,\n", + " checkpoint_every=checkpoint_every,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "DjwY7XrooLX_" + }, + "outputs": [], + "source": [ + "#@title Make the video\n", + "\n", + "if 'width' not in locals():\n", + " width = height = 512\n", + "\n", + "\n", + "# @markdown If your video is larger than a few MB, attempting to embed it will probably crash\n", + "# @markdown the session. 
If this happens, view the generated video after downloading it first.\n", + "embed_video = True # @param {type:'boolean'}\n", + "download_video = False # @param {type:'boolean'}\n", + "\n", + "upscale_video = False # @param {type:'boolean'}\n", + "\n", + "\n", + "outdir_str = str(outdir)\n", + "\n", + "fps = 14 # @param {type:\"integer\"}\n", + "out_fname = \"out.mp4\" # @param {type: \"string\"}\n", + "\n", + "out_fullpath = str( outdir / out_fname )\n", + "print(f\"Video will be saved to: {out_fullpath}\")\n", + "\n", + "compile_video_cmd = f\"ffmpeg -y -r {fps} -i 'out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p \" # {out_fname}\"\n", + "if upscale_video:\n", + " compile_video_cmd += f\"-vf scale={2*width}x{2*height}:flags=lanczos \"\n", + "compile_video_cmd += f\"{out_fname}\"\n", + "\n", + "print('\\nMaking the video...\\n')\n", + "!cd {outdir_str}; {compile_video_cmd}\n", + "\n", + "\n", + "debug=True\n", + "if debug:\n", + " #outdir_str = \"debug_frames\"\n", + " print(\"\\nMaking debug video...\")\n", + " #!cd debug_frames; ffmpeg -y -r {fps} -i 'prompts_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", + " !cd {debug_dir}; ffmpeg -y -r {fps} -i 'debug_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", + "\n", + "if embed_video:\n", + " print('\\nThe video:')\n", + " show_video(out_fullpath)\n", + " if debug:\n", + " show_video(debug_dir / \"debug_out.mp4\")\n", + "\n", + "if download_video and probably_using_colab:\n", + " from google.colab import files\n", + " files.download(out_fullpath)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rK_GlP_7WJiu", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Licensed under the MIT License { display-mode: \"form\" }\n", + "\n", + "# Copyright (c) 2022 Katherine Crowson \n", + "# Copyright (c) 2023 David Marx \n", + "# Copyright (c) 2022 deforum and contributors\n", + "\n", + "# Permission is hereby granted, free of 
charge, to any person obtaining a copy\n", + "# of this software and associated documentation files (the \"Software\"), to deal\n", + "# in the Software without restriction, including without limitation the rights\n", + "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", + "# copies of the Software, and to permit persons to whom the Software is\n", + "# furnished to do so, subject to the following conditions:\n", + "\n", + "# The above copyright notice and this permission notice shall be included in\n", + "# all copies or substantial portions of the Software.\n", + "\n", + "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", + "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", + "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", + "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", + "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", + "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", + "# THE SOFTWARE." 
+ ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "private_outputs": true, + "provenance": [] + }, + "gpuClass": "premium", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "ff1624fd81a21ea709585fb1fdce5419f857f6a9e76cb1632f1b8b574978f9ee" + } + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Ty3IOeXbLzvc", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Check GPU\n", - "#!nvidia-smi\n", - "\n", - "import pandas as pd\n", - "import subprocess\n", - "\n", - "def gpu_info():\n", - " outv = subprocess.run([\n", - " 'nvidia-smi',\n", - " # these lines concatenate into a single query string\n", - " '--query-gpu='\n", - " 'timestamp,'\n", - " 'name,'\n", - " 'utilization.gpu,'\n", - " 'utilization.memory,'\n", - " 'memory.used,'\n", - " 'memory.free,'\n", - " ,\n", - " '--format=csv'\n", - " ],\n", - " stdout=subprocess.PIPE).stdout.decode('utf-8')\n", - "\n", - " header, rec = outv.split('\\n')[:-1]\n", - " return pd.DataFrame({' '.join(k.strip().split('.')).capitalize():v for k,v in zip(header.split(','), rec.split(','))}, index=[0]).T\n", - "\n", - "gpu_info()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "kelHR9VM1-hg", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Install Dependencies\n", - "\n", - "# @markdown Your runtime will automatically restart after running this cell.\n", - "# @markdown You should only need to run this cell when setting up a new runtime. 
After future runtime restarts,\n", - "# @markdown you should be able to skip this cell.\n", - "\n", - "import warnings\n", - "\n", - "probably_using_colab = False\n", - "try:\n", - " import google\n", - " probably_using_colab = True\n", - "except ImportError:\n", - " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", - "\n", - "# @markdown Not recommended for colab users. This notebook is currently configured to only make this\n", - "# @markdown option available for local install.\n", - "use_xformers = False\n", - "\n", - "try:\n", - " import keyframed\n", - "except ImportError:\n", - " if probably_using_colab:\n", - " !pip install ftfy einops braceexpand requests transformers clip open_clip_torch omegaconf pytorch-lightning kornia k-diffusion ninja omegaconf\n", - " !pip install -U git+https://github.com/huggingface/huggingface_hub\n", - " !pip install napm keyframed\n", - " else:\n", - " !pip install -r klmc2/requirements.txt\n", - " if use_xformers:\n", - " !pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n", - "\n", - " exit() # restarts the runtime" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "fJZtXShcPXx5", - "tags": [] - }, - "outputs": [], - "source": [ - "# @markdown # Setup Workspace { display-mode: \"form\" }\n", - "\n", - "###################\n", - "# Setup Workspace #\n", - "###################\n", - "\n", - "import os\n", - "from pathlib import Path\n", - "import warnings\n", - "\n", - "probably_using_colab = False\n", - "try:\n", - " import google\n", - " if Path('/content').exists():\n", - " probably_using_colab = True\n", - " print(\"looks like we're in colab\")\n", - " else:\n", - " print(\"looks like we're not in colab\")\n", - "except ImportError:\n", - " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", - "\n", - "\n", - "mount_gdrive = True # 
@param {type:'boolean'}\n", - "\n", - "# defaults\n", - "outdir = Path('./frames')\n", - "if not os.environ.get('XDG_CACHE_HOME'):\n", - " os.environ['XDG_CACHE_HOME'] = str(Path('~/.cache').expanduser())\n", - "\n", - "if mount_gdrive and probably_using_colab:\n", - " from google.colab import drive\n", - " drive.mount('/content/drive')\n", - " Path('/content/drive/MyDrive/AI/models/.cache/').mkdir(parents=True, exist_ok=True) \n", - " os.environ['XDG_CACHE_HOME']='/content/drive/MyDrive/AI/models/.cache'\n", - " outdir = Path('/content/drive/MyDrive/AI/klmc2/frames/')\n", - "\n", - "# make sure the paths we need exist\n", - "outdir.mkdir(parents=True, exist_ok=True)\n", - "debug_dir = outdir.parent / 'debug_frames'\n", - "debug_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "\n", - "os.environ['NAPM_PATH'] = str( Path(os.environ['XDG_CACHE_HOME']) / 'napm' )\n", - "Path(os.environ['NAPM_PATH']).mkdir(parents=True, exist_ok=True)\n", - "\n", - "\n", - "import napm\n", - "\n", - "url = 'https://github.com/Stability-AI/stablediffusion'\n", - "napm.pseudoinstall_git_repo(url, add_install_dir_to_path=True)\n", - "\n", - "\n", - "##### Moved from model loading cell\n", - "\n", - "if probably_using_colab:\n", - " models_path = \"/content/models\" #@param {type:\"string\"}\n", - "else:\n", - " models_path = os.environ['XDG_CACHE_HOME']\n", - "\n", - "if mount_gdrive and probably_using_colab:\n", - " models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n", - " models_path = models_path_gdrive\n", - "\n", - "if not Path(models_path).exists():\n", - " Path(models_path).mkdir(parents=True, exist_ok=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "y2jXKIf2ZkT8", - "tags": [] - }, - "outputs": [], - "source": [ - "# @markdown # Imports and Definitions { display-mode: \"form\" }\n", - "\n", - "###########\n", - "# imports #\n", - "###########\n", - "\n", - "# importing napm 
puts the stable diffusion repo on the PATH, which is where `ldm` imports from\n", - "import napm\n", - "from ldm.util import instantiate_from_config\n", - "\n", - "from base64 import b64encode\n", - "from collections import defaultdict\n", - "from concurrent import futures\n", - "import math\n", - "from pathlib import Path\n", - "import random\n", - "import re\n", - "import requests\n", - "from requests.exceptions import HTTPError\n", - "import sys\n", - "import time\n", - "from urllib.parse import urlparse\n", - "import warnings\n", - "\n", - "import functorch\n", - "import huggingface_hub\n", - "from IPython.display import display, Video, HTML\n", - "import k_diffusion as K\n", - "from keyframed import Curve, ParameterGroup, SmoothCurve\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np \n", - "from omegaconf import OmegaConf\n", - "import PIL\n", - "from PIL import Image\n", - "import torch\n", - "from torch import nn\n", - "from tqdm.auto import tqdm, trange\n", - "\n", - "from loguru import logger\n", - "import time\n", - "from natsort import natsorted\n", - "\n", - "\n", - "#########################\n", - "# Define useful globals #\n", - "#########################\n", - "\n", - "cpu = torch.device(\"cpu\")\n", - "device = torch.device(\"cuda\")\n", - "\n", - "\n", - "##############################\n", - "# Define necessary functions #\n", - "##############################\n", - " \n", - "import matplotlib.image\n", - "import numpy as np\n", - " \n", - "def get_latest_frame(i=None, latest_frame_fpath=None):\n", - " latest_frame = None\n", - " if latest_frame_fpath is not None:\n", - " latest_frame = latest_frame_fpath\n", - " \n", - " if (latest_frame is None) and (i is None):\n", - " frames = Path('frames').glob(\"*.png\")\n", - " #logger.debug(f\"frames: {len(frames)}\")\n", - " latest_frame = natsort.sort(frames)[-1]\n", - " i = re.findall('out_([0-9]+).png', latest_frame.name)\n", - " else:\n", - " latest_frame = Path('frames') / 
f\"out_{i:05}.png\"\n", - " logger.debug(f'-latest_frame: {latest_frame}')\n", - " #return Image.open(latest_frame)\n", - " img = matplotlib.image.imread(latest_frame)\n", - " return np.flip(img, axis=0) # up/down\n", - "\n", - "def plot_prompts(prompts=None, n=1000, settings=None, **kargs):\n", - " if prompts is not None:\n", - " for prompt in prompts:\n", - " prompt.weight.plot(n=n, **kargs)\n", - "\n", - "def plot_param(param, settings=None, prompts=None, n=1000, **kargs):\n", - " settings.parameters[param].plot(n=n, **kargs)\n", - " \n", - "# move imports up\n", - "import base64\n", - "from io import BytesIO\n", - "from functools import partial\n", - " \n", - "@logger.catch\n", - "def write_debug_frame_at_(\n", - " i=None,\n", - " n=300, \n", - " prompts=None, \n", - " stuff_to_plot=['prompts'], \n", - " latest_frame_fpath=None,\n", - " pil_image=None,\n", - " settings=None,\n", - "):\n", - " plotting_funcs = {\n", - " 'prompts': plot_prompts,\n", - " 'g': partial(plot_param, param='g'),\n", - " 'h': partial(plot_param, param='h'),\n", - " 'sigma': partial(plot_param, param='sigma'),\n", - " 'gamma': partial(plot_param, param='gamma'),\n", - " 'alpha': partial(plot_param, param='alpha'),\n", - " 'tau': partial(plot_param, param='tau'),\n", - " }\n", - " \n", - " # i feel like this line of code justifies the silly variable name\n", - " if not stuff_to_plot:\n", - " return\n", - " \n", - " #stuff_to_plot = []\n", - " \n", - " test_im = pil_image\n", - " if pil_image is None:\n", - " test_im = get_latest_frame(i, latest_frame_fpath)\n", - "\n", - " fig = plt.figure()\n", - " #axsRight = fig.subplots(3, 1, sharex=True)\n", - " #ax = axsRight[0]\n", - " ax_objs = fig.subplots(len(stuff_to_plot), 1, sharex=True)\n", - " \n", - " #width, height = test_im.size\n", - " height, width = test_im.size\n", - " fig.set_size_inches(height/fig.dpi, width/fig.dpi )\n", - " \n", - " buffer = BytesIO()\n", - " for j, category in enumerate(stuff_to_plot):\n", - " ax = ax_objs\n", - 
" if len(stuff_to_plot) > 1:\n", - " ax = ax_objs[j]\n", - " plt.sca(ax)\n", - " plt.tight_layout()\n", - " plt.axis('off')\n", - " \n", - " plotting_funcs[category](prompts=prompts, settings=settings, n=n, zorder=1)\n", - " plt.axvline(x=i)\n", - " \n", - " \n", - "\n", - " #plt.margins(0)\n", - " fig.savefig(buffer, transparent=True) \n", - " plt.close()\n", - "\n", - " buffer.seek(0)\n", - " plot_pil = Image.open(buffer)\n", - " #buffer.close() # throws error here\n", - "\n", - " #debug_im_path = Path('debug_frames') / f\"{category}_out_{i:05}.png\"\n", - " #debug_im_path = Path('debug_frames') / f\"debug_out_{i:05}.png\"\n", - " debug_im_path = debug_dir / f\"debug_out_{i:05}.png\"\n", - " test_im = test_im.convert('RGBA')\n", - " test_im.paste(plot_pil, (0,0), plot_pil)\n", - " test_im.save(debug_im_path)\n", - " #display(test_im) # maybe?\n", - " buffer.close() # I guess?\n", - " \n", - " return test_im, plot_pil\n", - "\n", - "##############################\n", - "\n", - "class Prompt:\n", - " def __init__(\n", - " self,\n", - " text,\n", - " weight_schedule,\n", - " ):\n", - " c = sd_model.get_learned_conditioning([text])\n", - " self.text=text\n", - " self.encoded=c\n", - " self.weight = SmoothCurve(weight_schedule)\n", - "\n", - "\n", - "def handle_chigozienri_curve_format(value_string):\n", - " if value_string.startswith('(') and value_string.endswith(')'):\n", - " value_string = value_string[1:-1]\n", - " return value_string\n", - "\n", - "def parse_curve_string(txt, f=float):\n", - " schedule = {}\n", - " for tokens in txt.split(','):\n", - " k,v = tokens.split(':')\n", - " v = handle_chigozienri_curve_format(v)\n", - " schedule[int(k)] = f(v)\n", - " return schedule\n", - "\n", - "def parse_curvable_string(param, is_int=False):\n", - " if isinstance(param, dict):\n", - " return param\n", - " f = float\n", - " if is_int:\n", - " f = int\n", - " try:\n", - " return f(param)\n", - " except ValueError:\n", - " return parse_curve_string(txt=param, f=f)\n", 
- "\n", - "##################\n", - "\n", - "def show_video(video_path, video_width=512):\n", - " return display(Video(video_path, width=video_width))\n", - "\n", - "if probably_using_colab:\n", - " def show_video(video_path, video_width=512):\n", - " video_file = open(video_path, \"r+b\").read()\n", - " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", - " return display(HTML(f\"\"\"\"\"\"))\n", - "\n", - "##################\n", - "\n", - "class NormalizingCFGDenoiser(nn.Module):\n", - " def __init__(self, model, g):\n", - " super().__init__()\n", - " self.inner_model = model\n", - " self.g = g\n", - " self.eps_norms = defaultdict(lambda: (0, 0))\n", - "\n", - " def mean_sq(self, x):\n", - " return x.pow(2).flatten(1).mean(1)\n", - "\n", - " @torch.no_grad()\n", - " def update_eps_norm(self, eps, sigma):\n", - " sigma = sigma[0].item()\n", - " eps_norm = self.mean_sq(eps).mean()\n", - " eps_norm_avg, count = self.eps_norms[sigma]\n", - " eps_norm_avg = eps_norm_avg * count / (count + 1) + eps_norm / (count + 1)\n", - " self.eps_norms[sigma] = (eps_norm_avg, count + 1)\n", - " return eps_norm_avg\n", - "\n", - " def forward(self, x, sigma, uncond, cond, g):\n", - " x_in = torch.cat([x] * 2)\n", - " sigma_in = torch.cat([sigma] * 2)\n", - " cond_in = torch.cat([uncond, cond])\n", - "\n", - " denoised = self.inner_model(x_in, sigma_in, cond=cond_in)\n", - " eps = K.sampling.to_d(x_in, sigma_in, denoised)\n", - " eps_uc, eps_c = eps.chunk(2)\n", - " eps_norm = self.update_eps_norm(eps, sigma).sqrt()\n", - " c = eps_c - eps_uc\n", - " cond_scale = g * eps_norm / self.mean_sq(c).sqrt()\n", - " eps_final = eps_uc + c * K.utils.append_dims(cond_scale, x.ndim)\n", - " return x - eps_final * K.utils.append_dims(sigma, eps.ndim)\n", - "\n", - "#########################\n", - "\n", - "def write_klmc2_state(**state):\n", - " st = time.time()\n", - " obj = {}\n", - " for k,v in state.items():\n", - " try:\n", - " v = v.clone().detach().cpu()\n", - " 
except AttributeError:\n", - " # if it doesn't have a detach method, we don't need to worry about any preprocessing\n", - " pass\n", - " obj[k] = v\n", - "\n", - " checkpoint_fpath = Path(outdir) / f\"klmc2_state_{state.get('i',0):05}.ckpt\"\n", - " with open(checkpoint_fpath, 'wb') as f:\n", - " torch.save(obj, f=f)\n", - " et = time.time()\n", - " #logger.debug(f\"checkpointing: {et-st}\")\n", - " # to do: move to callback? thread executor, anyway\n", - "\n", - "def read_klmc2_state(root=outdir, latest_frame=-1):\n", - " state = {}\n", - " checkpoints = [str(p) for p in Path(root).glob(\"*.ckpt\")]\n", - " if not checkpoints:\n", - " return None\n", - " checkpoints = natsorted(checkpoints)\n", - " if latest_frame < 0:\n", - " ckpt_fpath = checkpoints[-1]\n", - " else:\n", - " for fname in checkpoints:\n", - " frame_id = re.findall(r'([0-9]+).ckpt', fname)[0]\n", - " if int(frame_id) <= latest_frame:\n", - " ckpt_fpath = fname\n", - " else:\n", - " break\n", - " logger.debug(ckpt_fpath)\n", - " with open(ckpt_fpath,'rb') as f:\n", - " state = torch.load(f=f,map_location='cuda')\n", - " return state\n", - "\n", - "def load_init_image(init_image, height, width):\n", - " if not Path(init_image).exists():\n", - " raise FileNotFoundError(f\"Unable to locate init image from path: {init_image}\")\n", - " \n", - " \n", - " from PIL import Image\n", - " import numpy as np\n", - "\n", - " init_im_pil = Image.open(init_image)\n", - "\n", - " #x_pil = init_im_pil.resize([512,512])\n", - " x_pil = init_im_pil.resize([height,width])\n", - " x_np = np.array(x_pil.convert('RGB')).astype(np.float16) / 255.0\n", - " x = x_np[None].transpose(0, 3, 1, 2)\n", - " x = 2.*x - 1.\n", - " x = torch.from_numpy(x).to('cuda')\n", - " return x\n", - "\n", - "def save_image_fn(image, name, i, n, prompts=None, settings=None, stuff_to_plot=['prompts']):\n", - " pil_image = K.utils.to_pil_image(image)\n", - " if i % 10 == 0 or i == n - 1:\n", - " print(f'\\nIteration {i}/{n}:')\n", - " 
display(pil_image)\n", - " if i == n - 1:\n", - " print('\\nDone!')\n", - " pil_image.save(name)\n", - " if stuff_to_plot:\n", - " #logger.debug(stuff_to_plot)\n", - " #write_debug_frame_at_(i, prompts=prompts)\n", - " debug_frame, debug_plot = write_debug_frame_at_(i=i,n=n, prompts=prompts, settings=settings, stuff_to_plot=stuff_to_plot, pil_image=pil_image)\n", - " if i % 10 == 0 or i == n - 1:\n", - " #display(debug_frame)\n", - " display(debug_plot)\n", - "\n", - "###############################\n", - "\n", - "@torch.no_grad()\n", - "def sample_mcmc_klmc2(\n", - " sd_model, \n", - " init_image,\n", - " height:int,\n", - " width:int,\n", - " n:int, \n", - " hvp_method:str='reverse', \n", - " prompts:list=None,\n", - " settings:ParameterGroup=None,\n", - " resume:bool = False,\n", - " resume_from:int=-1,\n", - " img_init_steps:int=None,\n", - " stuff_to_plot:list=None,\n", - " checkpoint_every:int=10,\n", - "):\n", - "\n", - " if stuff_to_plot is None:\n", - " stuff_to_plot = ['prompts','h']\n", - " \n", - " torch.cuda.empty_cache()\n", - "\n", - " wrappers = {'eps': K.external.CompVisDenoiser, 'v': K.external.CompVisVDenoiser}\n", - " g = settings[0]['g']\n", - "\n", - " model_wrap = wrappers[sd_model.parameterization](sd_model)\n", - " model_wrap_cfg = NormalizingCFGDenoiser(model_wrap, g)\n", - " sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()\n", - " model = model_wrap_cfg\n", - "\n", - " uc = sd_model.get_learned_conditioning([''])\n", - " extra_args = {\n", - " 'uncond': uc, \n", - " #'cond': prompts[0].encoded,\n", - " 'g': settings[0]['g']\n", - " }\n", - "\n", - " sigma = settings[0]['sigma']\n", - "\n", - " with torch.cuda.amp.autocast(), futures.ThreadPoolExecutor() as ex:\n", - " def callback(info):\n", - " i = info['i']\n", - " rgb = sd_model.decode_first_stage(info['denoised'] )\n", - " ex.submit(save_image_fn, image=rgb, name=(outdir / f'out_{i:05}.png'), i=i, n=n, prompts=prompts, settings=settings, 
stuff_to_plot=stuff_to_plot)\n", - "\n", - " # Initialize the chain\n", - " print('Initializing the chain...')\n", - "\n", - " # to do: if resuming, generating this init image is unnecessary\n", - " x = None\n", - " if init_image:\n", - " print(\"loading init image\")\n", - " x = load_init_image(init_image, height, width)\n", - " # convert RGB to latent\n", - " x = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x))\n", - " print(\"init image loaded.\")\n", - "\n", - " print('Actually doing the sampling...')\n", - "\n", - " i_resume=0\n", - " v = None\n", - " if resume:\n", - " state = read_klmc2_state(latest_frame=resume_from)\n", - " if state:\n", - " x, v, i_resume = state['x'], state['v'], state['i']\n", - " # to do: resumption of settings\n", - " settings_i = state['settings_i']\n", - " settings[i]['h'] = settings_i['h']\n", - " settings[i]['gamma'] = settings_i['gamma']\n", - " settings[i]['alpha'] = settings_i['alpha']\n", - " settings[i]['tau'] = settings_i['tau']\n", - " settings[i]['g'] = settings_i['g']\n", - " settings[i]['sigma'] = settings_i['sigma']\n", - " settings[i]['steps'] = settings_i['steps']\n", - " \n", - " # to do: use multicond for init image\n", - " # we want this test after resumption if resuming\n", - " if x is None:\n", - " print(\"No init image provided, generating a random init image\")\n", - " extra_args['cond'] = prompts[0].encoded\n", - " h=height//8\n", - " w=width//8\n", - " x = torch.randn([1, 4, h, w], device=device) * sigma_max\n", - " sigmas_pre = K.sampling.get_sigmas_karras(img_init_steps, sigma, sigma_max, device=x.device)[:-1]\n", - " x = K.sampling.sample_dpmpp_sde(model_wrap_cfg, x, sigmas_pre, extra_args=extra_args)\n", - "\n", - " # if not resuming, randomly initialize momentum\n", - " # this needs to be *after* generating X if we're going to...\n", - " if v is None:\n", - " v = torch.randn_like(x) * sigma\n", - "\n", - " # main sampling loop\n", - " for i in trange(n):\n", - " # fast-forward loop to 
resumption index\n", - " if resume and i < i_resume:\n", - " continue\n", - " # if resume and (i == i_resume):\n", - " # # should these values be written into settings[i]?\n", - " # h = settings_i['h']\n", - " # gamma = settings_i['gamma']\n", - " # alpha = settings_i['alpha']\n", - " # tau = settings_i['tau']\n", - " # g = settings_i['g']\n", - " # sigma = settings_i['sigma']\n", - " # steps = settings_i['steps']\n", - " # else:\n", - " h = settings[i]['h']\n", - " gamma = settings[i]['gamma']\n", - " alpha = settings[i]['alpha']\n", - " tau = settings[i]['tau']\n", - " g = settings[i]['g']\n", - " sigma = settings[i]['sigma']\n", - " steps = settings[i]['steps']\n", - "\n", - " h = torch.tensor(h, device=x.device)\n", - " gamma = torch.tensor(gamma, device=x.device)\n", - " alpha = torch.tensor(alpha, device=x.device)\n", - " tau = torch.tensor(tau, device=x.device)\n", - " sigma = torch.tensor(sigma, device=x.device)\n", - " steps = int(steps)\n", - " \n", - " sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma.item(), device=x.device)[:-1]\n", - "\n", - " x, v, grad = klmc2_step(\n", - " model,\n", - " prompts,\n", - " x,\n", - " v,\n", - " h,\n", - " gamma,\n", - " alpha,\n", - " tau,\n", - " g,\n", - " sigma,\n", - " sigmas,\n", - " steps,\n", - " hvp_method,\n", - " i,\n", - " callback,\n", - " extra_args,\n", - " )\n", - "\n", - " save_checkpoint = (i % checkpoint_every) == 0\n", - " if save_checkpoint:\n", - " settings_i = settings[i]\n", - " ex.submit(write_klmc2_state, v=v, x=x, i=i, settings_i=settings_i)\n", - " logger.debug(settings[i])\n", - "\n", - "\n", - "def hvp_fn_forward_functorch(model, x, sigma, v, alpha, extra_args):\n", - " def grad_fn(x, sigma):\n", - " s_in = x.new_ones([x.shape[0]])\n", - " denoised = model(x, sigma * s_in, **extra_args)\n", - " return (x - denoised) + alpha * x\n", - " jvp_fn = lambda v: functorch.jvp(grad_fn, (x, sigma), (v, torch.zeros_like(sigma)))\n", - " grad, jvp_out = functorch.vmap(jvp_fn)(v)\n", - " 
return grad[0], jvp_out\n", - "\n", - "def hvp_fn_reverse(model, x, sigma, v, alpha, extra_args):\n", - " def grad_fn(x, sigma):\n", - " s_in = x.new_ones([x.shape[0]])\n", - " denoised = model(x, sigma * s_in, **extra_args)\n", - " return (x - denoised) + alpha * x\n", - " vjps = []\n", - " with torch.enable_grad():\n", - " x_ = x.clone().requires_grad_()\n", - " grad = grad_fn(x_, sigma)\n", - " for k, item in enumerate(v):\n", - " vjp_out = torch.autograd.grad(grad, x_, item, retain_graph=k < len(v) - 1)[0]\n", - " vjps.append(vjp_out)\n", - " return grad, torch.stack(vjps)\n", - "\n", - "def hvp_fn_zero(model, x, sigma, v, alpha, extra_args):\n", - " def grad_fn(x, sigma):\n", - " s_in = x.new_ones([x.shape[0]])\n", - " denoised = model(x, sigma * s_in, **extra_args)\n", - " return (x - denoised) + alpha * x\n", - " return grad_fn(x, sigma), torch.zeros_like(v)\n", - "\n", - "def hvp_fn_fake(model, x, sigma, v, alpha, extra_args):\n", - " def grad_fn(x, sigma):\n", - " s_in = x.new_ones([x.shape[0]])\n", - " denoised = model(x, sigma * s_in, **extra_args)\n", - " return (x - denoised) + alpha * x\n", - " return grad_fn(x, sigma), (1 + alpha) * v\n", - "\n", - "\n", - "def multicond_hvp(model, x, sigma, v, alpha, extra_args, prompts, hvp_fn, i):\n", - "\n", - " # loop over prompts and aggregate gradients for multicond\n", - " grad = torch.zeros_like(x)\n", - " h2_v = torch.zeros_like(x)\n", - " h2_noise_v2 = torch.zeros_like(x)\n", - " h2_noise_x2 = torch.zeros_like(x)\n", - " wt_norm = 0\n", - " for prompt in prompts:\n", - " wt = prompt.weight[i]\n", - " if wt == 0:\n", - " continue\n", - " wt_norm += wt\n", - " wt = torch.tensor(wt, device=x.device)\n", - " extra_args['cond'] = prompt.encoded\n", - "\n", - " # Estimate gradient and hessian\n", - " grad_, (h2_v_, h2_noise_v2_, h2_noise_x2_) = hvp_fn(\n", - " model=model,\n", - " x=x, \n", - " sigma=sigma, \n", - " v=v,\n", - " alpha=alpha,\n", - " extra_args=extra_args,\n", - " )\n", - "\n", - " grad = grad + 
grad_ * wt \n", - " h2_v = h2_v + h2_v_ * wt\n", - " h2_noise_v2 = h2_noise_v2 + h2_noise_v2_ * wt\n", - " h2_noise_x2 = h2_noise_x2 + h2_noise_x2_ * wt\n", - "\n", - " # Normalize gradient to magnitude it'd have if just single prompt w/ wt=1.\n", - " # simplifies multicond w/o deep frying image or adding hyperparams\n", - " grad = grad / wt_norm \n", - " h2_v = h2_v / wt_norm\n", - " h2_noise_v2 = h2_noise_v2 / wt_norm\n", - " h2_noise_x2 = h2_noise_x2 / wt_norm\n", - "\n", - " return grad, h2_v, h2_noise_v2, h2_noise_v2, h2_noise_x2\n", - "\n", - "\n", - "\n", - "\n", - "def klmc2_step(\n", - " model,\n", - " prompts,\n", - " x,\n", - " v,\n", - " h,\n", - " gamma,\n", - " alpha,\n", - " tau,\n", - " g,\n", - " sigma,\n", - " sigmas,\n", - " steps,\n", - " hvp_method,\n", - " i,\n", - " callback,\n", - " extra_args,\n", - " ):\n", - "\n", - " #s_in = x.new_ones([x.shape[0]])\n", - "\n", - " # Model helper functions\n", - "\n", - " hvp_fns = {'forward-functorch': hvp_fn_forward_functorch,\n", - " 'reverse': hvp_fn_reverse,\n", - " 'zero': hvp_fn_zero,\n", - " 'fake': hvp_fn_fake}\n", - "\n", - " hvp_fn = hvp_fns[hvp_method]\n", - "\n", - " # KLMC2 helper functions\n", - " def psi_0(gamma, t):\n", - " return torch.exp(-gamma * t)\n", - "\n", - " def psi_1(gamma, t):\n", - " return -torch.expm1(-gamma * t) / gamma\n", - "\n", - " def psi_2(gamma, t):\n", - " return (torch.expm1(-gamma * t) + gamma * t) / gamma ** 2\n", - "\n", - " def phi_2(gamma, t_):\n", - " t = t_.double()\n", - " out = (torch.exp(-gamma * t) * (torch.expm1(gamma * t) - gamma * t)) / gamma ** 2\n", - " return out.to(t_)\n", - "\n", - " def phi_3(gamma, t_):\n", - " t = t_.double()\n", - " out = (torch.exp(-gamma * t) * (2 + gamma * t + torch.exp(gamma * t) * (gamma * t - 2))) / gamma ** 3\n", - " return out.to(t_)\n", - "\n", - "\n", - " # Compute model outputs and sample noise\n", - " x_trapz = torch.linspace(0, h, 1001, device=x.device)\n", - " y_trapz = [fun(gamma, x_trapz) for fun in (psi_0, 
psi_1, phi_2, phi_3)]\n", - " noise_cov = torch.tensor([[torch.trapz(y_trapz[i] * y_trapz[j], x=x_trapz) for j in range(4)] for i in range(4)], device=x.device)\n", - " noise_v, noise_x, noise_v2, noise_x2 = torch.distributions.MultivariateNormal(x.new_zeros([4]), noise_cov).sample(x.shape).unbind(-1)\n", - "\n", - " extra_args['g']=g\n", - "\n", - " # compute derivatives, multicond wrapper loops over prompts and averages derivatives\n", - " grad, h2_v, h2_noise_v2, h2_noise_v2, h2_noise_x2 = multicond_hvp(\n", - " model=model, \n", - " x=x, \n", - " sigma=sigma, \n", - " v=torch.stack([v, noise_v2, noise_x2]), # need a \"dummy\" v for init image generation\n", - " alpha=alpha, \n", - " extra_args=extra_args, \n", - " prompts=prompts, \n", - " hvp_fn=hvp_fn,\n", - " i=i,\n", - " )\n", - "\n", - " # DPM-Solver++(2M) refinement steps\n", - " x_refine = x\n", - " use_dpm = True\n", - " old_denoised = None\n", - " for j in range(len(sigmas) - 1):\n", - " if j == 0:\n", - " denoised = x_refine - grad\n", - " else:\n", - " s_in = x.new_ones([x.shape[0]])\n", - " denoised = model(x_refine, sigmas[j] * s_in, **extra_args)\n", - " dt_ode = sigmas[j + 1] - sigmas[j]\n", - " if not use_dpm or old_denoised is None or sigmas[j + 1] == 0:\n", - " eps = K.sampling.to_d(x_refine, sigmas[j], denoised)\n", - " x_refine = x_refine + eps * dt_ode\n", - " else:\n", - " h_ode = sigmas[j].log() - sigmas[j + 1].log()\n", - " h_last = sigmas[j - 1].log() - sigmas[j].log()\n", - " fac = h_ode / (2 * h_last)\n", - " denoised_d = (1 + fac) * denoised - fac * old_denoised\n", - " eps = K.sampling.to_d(x_refine, sigmas[j], denoised_d)\n", - " x_refine = x_refine + eps * dt_ode\n", - " old_denoised = denoised\n", - " if callback is not None:\n", - " callback({'i': i, 'denoised': x_refine})\n", - "\n", - " # Update the chain\n", - " noise_std = (2 * gamma * tau * sigma ** 2).sqrt()\n", - " v_next = 0 + psi_0(gamma, h) * v - psi_1(gamma, h) * grad - phi_2(gamma, h) * h2_v + noise_std * (noise_v - 
h2_noise_v2)\n", - " x_next = x + psi_1(gamma, h) * v - psi_2(gamma, h) * grad - phi_3(gamma, h) * h2_v + noise_std * (noise_x - h2_noise_x2)\n", - " v, x = v_next, x_next\n", - "\n", - " return x, v, grad " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "yt3d1hww17ST", - "tags": [] - }, - "outputs": [], - "source": [ - "#@markdown **Select and Load Model**\n", - "\n", - "## TO DO:\n", - "## - if local, try to load model from ~/.cache/huggingface/diffusers\n", - "\n", - "# modified from:\n", - "# https://github.com/deforum/stable-diffusion/blob/main/Deforum_Stable_Diffusion.ipynb\n", - "\n", - "import napm\n", - "from ldm.util import instantiate_from_config\n", - "\n", - "\n", - "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", - "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\", \"robo-diffusion-v1.ckpt\",\"waifu-diffusion-v1-3.ckpt\"]\n", - "if model_checkpoint == \"waifu-diffusion-v1-3.ckpt\":\n", - " model_checkpoint = \"model-epoch05-float16.ckpt\"\n", - "custom_config_path = \"\" #@param {type:\"string\"}\n", - "custom_checkpoint_path = \"\" #@param {type:\"string\"}\n", - "\n", - "half_precision = True # check\n", - "check_sha256 = False #@param {type:\"boolean\"}\n", - "\n", - "model_map = {\n", - " \"sd-v1-4-full-ema.ckpt\": {\n", - " 'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-4-full-ema.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-4.ckpt\": {\n", - " 'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',\n", 
- " 'requires_login': True,\n", - " },\n", - " \"sd-v1-3-full-ema.ckpt\": {\n", - " 'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-3.ckpt\": {\n", - " 'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-2-full-ema.ckpt\": {\n", - " 'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-2.ckpt\": {\n", - " 'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-1-full-ema.ckpt\": {\n", - " 'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',\n", - " 'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-1.ckpt\": {\n", - " 'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"robo-diffusion-v1.ckpt\": {\n", - " 'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',\n", - " 'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',\n", - " 'requires_login': False,\n", - " },\n", - " \"model-epoch05-float16.ckpt\": {\n", - " 
'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece',\n", - " 'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt',\n", - " 'requires_login': False,\n", - " },\n", - "}\n", - "\n", - "# config path\n", - "ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n", - "if os.path.exists(ckpt_config_path):\n", - " print(f\"{ckpt_config_path} exists\")\n", - "else:\n", - " #ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n", - " ckpt_config_path = \"./v1-inference.yaml\"\n", - " if not Path(ckpt_config_path).exists():\n", - " !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n", - " \n", - "print(f\"Using config: {ckpt_config_path}\")\n", - "\n", - "# checkpoint path or download\n", - "ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n", - "ckpt_valid = True\n", - "if os.path.exists(ckpt_path):\n", - " print(f\"{ckpt_path} exists\")\n", - "elif 'url' in model_map[model_checkpoint]:\n", - " url = model_map[model_checkpoint]['url']\n", - "\n", - " # CLI dialogue to authenticate download\n", - " if model_map[model_checkpoint]['requires_login']:\n", - " print(\"This model requires an authentication token\")\n", - " print(\"Please ensure you have accepted its terms of service before continuing.\")\n", - "\n", - " username = input(\"What is your huggingface username?:\")\n", - " token = input(\"What is your huggingface token?:\")\n", - "\n", - " _, path = url.split(\"https://\")\n", - "\n", - " url = f\"https://{username}:{token}@{path}\"\n", - "\n", - " # contact server for model\n", - " print(f\"Attempting to download {model_checkpoint}...this may take a while\")\n", - " ckpt_request = requests.get(url)\n", - " request_status = ckpt_request.status_code\n", - "\n", - " # inform 
user of errors\n", - " if request_status == 403:\n", - " raise ConnectionRefusedError(\"You have not accepted the license for this model.\")\n", - " elif request_status == 404:\n", - " raise ConnectionError(\"Could not make contact with server\")\n", - " elif request_status != 200:\n", - " raise ConnectionError(f\"Some other error has ocurred - response code: {request_status}\")\n", - "\n", - " # write to model path\n", - " with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file:\n", - " model_file.write(ckpt_request.content)\n", - "else:\n", - " print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n", - " ckpt_valid = False\n", - "\n", - "if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n", - " import hashlib\n", - " print(\"\\n...checking sha256\")\n", - " with open(ckpt_path, \"rb\") as f:\n", - " bytes = f.read() \n", - " hash = hashlib.sha256(bytes).hexdigest()\n", - " del bytes\n", - " if model_map[model_checkpoint][\"sha256\"] == hash:\n", - " print(\"hash is correct\\n\")\n", - " else:\n", - " print(\"hash in not correct\\n\")\n", - " ckpt_valid = False\n", - "\n", - "if ckpt_valid:\n", - " print(f\"Using ckpt: {ckpt_path}\")\n", - "\n", - "def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n", - " map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n", - " print(f\"Loading model from {ckpt}\")\n", - " pl_sd = torch.load(ckpt, map_location=map_location)\n", - " if \"global_step\" in pl_sd:\n", - " print(f\"Global Step: {pl_sd['global_step']}\")\n", - " sd = pl_sd[\"state_dict\"]\n", - " model = instantiate_from_config(config.model)\n", - " m, u = model.load_state_dict(sd, strict=False)\n", - " if len(m) > 0 and verbose:\n", - " print(\"missing keys:\")\n", - " print(m)\n", - " if len(u) > 0 and verbose:\n", - " print(\"unexpected keys:\")\n", - " print(u)\n", - "\n", - " if half_precision:\n", - " model = 
model.half().to(device)\n", - " else:\n", - " model = model.to(device)\n", - " model.eval()\n", - " return model\n", - "\n", - "if ckpt_valid:\n", - " local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n", - " model = load_model_from_config(local_config, f\"{ckpt_path}\", half_precision=half_precision)\n", - " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", - " model = model.to(device)\n", - "\n", - " # Disable checkpointing as it is not compatible with the method\n", - " for module in model.modules():\n", - " if hasattr(module, 'checkpoint'):\n", - " module.checkpoint = False\n", - " if hasattr(module, 'use_checkpoint'):\n", - " module.use_checkpoint = False\n", - "\n", - " sd_model=model\n", - "\n", - "####################################################################\n", - "\n", - "use_new_vae = True #@param {type:\"boolean\"}\n", - "\n", - "if use_new_vae:\n", - "\n", - " # from kat's notebook again\n", - "\n", - " def download_from_huggingface(repo, filename):\n", - " while True:\n", - " try:\n", - " return huggingface_hub.hf_hub_download(repo, filename)\n", - " except HTTPError as e:\n", - " if e.response.status_code == 401:\n", - " # Need to log into huggingface api\n", - " huggingface_hub.interpreter_login()\n", - " continue\n", - " elif e.response.status_code == 403:\n", - " # Need to do the click through license thing\n", - " print(f'Go here and agree to the click through license on your account: https://huggingface.co/{repo}')\n", - " input('Hit enter when ready:')\n", - " continue\n", - " else:\n", - " raise e\n", - "\n", - " vae_840k_model_path = download_from_huggingface(\"stabilityai/sd-vae-ft-mse-original\", \"vae-ft-mse-840000-ema-pruned.ckpt\")\n", - "\n", - " def load_model_from_config_kc(config, ckpt):\n", - " print(f\"Loading model from {ckpt}\")\n", - " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n", - " sd = pl_sd[\"state_dict\"]\n", - " config = OmegaConf.load(config)\n", - "\n", - " 
try:\n", - " config['model']['params']['lossconfig']['target'] = \"torch.nn.Identity\"\n", - " print('Patched VAE config.')\n", - " except KeyError:\n", - " pass\n", - "\n", - " model = instantiate_from_config(config.model)\n", - " m, u = model.load_state_dict(sd, strict=False)\n", - " model = model.to(cpu).eval().requires_grad_(False)\n", - " return model\n", - "\n", - " vaemodel_yaml_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/models/first_stage_models/kl-f8/config.yaml\"\n", - " vaemodel_yaml_fname = 'config_vae_kl-f8.yaml'\n", - " vaemodel_yaml_fname_git = \"latent-diffusion/models/first_stage_models/kl-f8/config.yaml\"\n", - " if Path(vaemodel_yaml_fname_git).exists():\n", - " vae_model = load_model_from_config_kc(vaemodel_yaml_fname_git, vae_840k_model_path).half().to(device)\n", - " else:\n", - " if not Path(vaemodel_yaml_fname).exists():\n", - " !wget {vaemodel_yaml_url} -O {vaemodel_yaml_fname}\n", - " vae_model = load_model_from_config_kc(vaemodel_yaml_fname, vae_840k_model_path).half().to(device)\n", - "\n", - " del sd_model.first_stage_model\n", - " sd_model.first_stage_model = vae_model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "ZljSF1ePnBl4", - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Settings\n", - "\n", - "# @markdown The number of frames to sample:\n", - "n = 300 # @param {type:\"integer\"}\n", - "\n", - "# @markdown height and width must be multiples of 8 (e.g. 256, 512, 768, 1024)\n", - "height = 512 # @param {type:\"integer\"}\n", - "\n", - "width = 512 # @param {type:\"integer\"}\n", - "\n", - "\n", - "# @markdown If seed is negative, a random seed will be used\n", - "seed = -1 # @param {type:\"number\"}\n", - "\n", - "init_image = \"\" # @param {type:'string'}\n", - "\n", - "# @markdown ---\n", - "\n", - "# @markdown Settings below this line can be parameterized using keyframe syntax: `\"time:weight, time:weight, ...\". 
\n", - "# @markdown Over spans where values of weights change, intermediate values will be interpolated using an \"s\" shaped curve.\n", - "# @markdown If a value for keyframe 0 is not specified, it is presumed to be `0:0`.\n", - "\n", - "# @markdown The strength of the conditioning on the prompt:\n", - "g=\"0:0.1\" # @param {type:\"string\"}\n", - "\n", - "# @markdown The noise level to sample at\n", - "# @markdown Ramp up from a tiny sigma if using init image, e.g. `0:0.25, 100:2, ...`\n", - "# @markdown NB: Turning sigma *up* mid generation seems to work fine, but turning sigma *down* mid generation tends to \"deep fry\" the outputs\n", - "sigma = \"1.25\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Step size (range 0 to 1):\n", - "h = \"0:0.1, 30:0.1, 50:0.3, 70:0.1, 120:0.1, 140:.3, 160:.1, 210:.1, 230:.3, 250:.1\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Friction (2 is critically damped, lower -> smoother animation):\n", - "gamma = \"1.1\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Quadratic penalty (\"weight decay\") strength:\n", - "alpha = \"0.005\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Temperature (adjustment to the amount of noise added per step):\n", - "tau = \"1.0\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Denoising refinement steps:\n", - "refinement_steps = \"6\" # @param {type:\"string\"}\n", - "\n", - "# @markdown If an init image is not provided, this is how many steps will be used when generating an initial state:\n", - "img_init_steps = 15 # @param {type:\"number\"}\n", - "\n", - "# @markdown The HVP method:\n", - "# @markdown
`forward-functorch` and `reverse` provide real second derivatives. Compatibility, speed, and memory usage vary by model and xformers configuration.\n", - "# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.\n", - "hvp_method = 'fake' # @param [\"forward-functorch\", \"reverse\", \"fake\", \"zero\"]\n", - "\n", - "checkpoint_every = 10 # @param {type:\"number\"}\n", - "\n", - "###########################\n", - "\n", - "assert (height % 8) == 0\n", - "assert (width % 8) == 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1pLTsdGBPXx6", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Prompts\n", - "\n", - "# [ \n", - "# [\"first prompt will be used to initialize the image\", {time:weight, time:weight...}], \n", - "# [\"more prompts if you want\", {...}], \n", - "# ...]\n", - "\n", - "# if a weight for time=0 isn't specified, the weight is assumed to be zero.\n", - "# if you want to \"fade in\" any prompts, it's best to have them start with a small but non-zero value, e.g. 
0.001\n", - "\n", - "prompt_params = [\n", - " # # FIRST PROMPT INITIALIZES IMAGE\n", - " #[\"sweetest puppy, golden retriever\", {0:.5, 30:0.5, 100:0.001}],\n", - " #[\"sweet old dog, golden retriever\", {0:0.001, 30:0.001, 100:0.5}],\n", - " #[\"happiest pupper, cutest dog evar, golden retriever, incredibly photogenic dog\", {0:1}],\n", - "\n", - " # # the 'flowers prompts' below go with a particular 'h' setting in the next cell\n", - " [\"incredibly beautiful orchids, a bouquet of orchids\", {0:1, 35:1, 50:0}],\n", - " [\"incredibly beautiful roses, a bouquet of roses\", {0:0.001, 35:0.001, 50:1, 120:1, 140:0}],\n", - " [\"incredibly beautiful carnations, a bouquet of carnations\", {0:0.001, 120:0.001, 140:1, 220:1, 240:0}],\n", - " [\"incredibly beautiful carnations, a bouquet of sunflowers\", {0:0.001, 220:0.001, 240:1}],\n", - " \n", - " # negative prompts\n", - " [\"watermark text\", {0:-0.1} ],\n", - " [\"jpeg artifacts\", {0:-0.1} ],\n", - " [\"artist's signature\", {0:-0.1} ],\n", - " [\"istockphoto, gettyimages, watermarked image\", {0:-0.1} ],\n", - "]\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "hZ0lh-WkdB19", - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Build prompt and settings objects\n", - "\n", - "# @markdown some advanced features demonstrated in commented-out code in this cell\n", - "\n", - "plot_prompt_weight_curves = True # @param {type: 'boolean'}\n", - "plot_settings_weight_curves = True # @param {type: 'boolean'}\n", - "\n", - "#################\n", - "\n", - "# Build Prompt objects\n", - "\n", - "prompts = [\n", - " Prompt(text, weight_schedule) \n", - " for (text, weight_schedule) in prompt_params\n", - "]\n", - "\n", - "# uncomment to loop the prompts\n", - "#for p in prompts:\n", - "# if len(p.weight.keyframes) > 1: # ignore negative prompts\n", - "# p.weight.loop=True \n", - "\n", - "# uncomment to loop prompts in \"bounce\" mode\n", - 
"#for p in prompts:\n", - "# if len(p.weight.keyframes) > 1:\n", - "# p.weight.bounce=True \n", - "\n", - "#################\n", - "\n", - "# Build Settings object\n", - "\n", - "g = parse_curvable_string(g)\n", - "sigma = parse_curvable_string(sigma)\n", - "h = parse_curvable_string(h)\n", - "gamma = parse_curvable_string(gamma)\n", - "alpha = parse_curvable_string(alpha)\n", - "tau = parse_curvable_string(tau)\n", - "steps = parse_curvable_string(refinement_steps)\n", - "\n", - "\n", - "curved_settings = ParameterGroup({\n", - " 'g':SmoothCurve(g),\n", - " 'sigma':SmoothCurve(sigma),\n", - " #'h':SmoothCurve(h),\n", - " \n", - " # more concise notation for flowers demo:\n", - " 'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", - " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3, 70:0.1, 90:0.1}, loop=True),\n", - "\n", - " 'gamma':SmoothCurve(gamma),\n", - " 'alpha':SmoothCurve(alpha),\n", - " 'tau':SmoothCurve(tau),\n", - " 'steps':SmoothCurve(steps),\n", - "})\n", - "\n", - "\n", - "if plot_prompt_weight_curves:\n", - " for prompt in prompts:\n", - " prompt.weight.plot(n=n)\n", - " plt.title(\"prompt weight schedules\")\n", - " plt.show()\n", - "\n", - "\n", - "if plot_settings_weight_curves:\n", - " for name, curve in curved_settings.parameters.items():\n", - " curve.plot(n=n)\n", - " plt.title(name)\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import keyframed.serialization\n", - "txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", - "\n", - "#print(txt)\n", - "\n", - "# sigma: 1.25\n", - "#\n", - "# becomes:\n", - "#\n", - "# sigma:\n", - "# curve:\n", - "# - - 0\n", - "# - 1.25\n", - "# - eased_lerp\n", - "#\n", - "# :\n", - "# curve:\n", - "# - - \n", - "# - \n", - "# - \n", - "# - \n", - "# - - \n", - "# - \n", - "# - - \n", - "# - \n", - "\n", - "with open(outdir / 'settings.yaml', 'w') as f:\n", - " f.write(txt)" - ] - }, - { 
- "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# load settings from disk\n", - "\n", - "load_settings_from_disk = True # @param {type:'boolean'}\n", - "\n", - "if load_settings_from_disk:\n", - " with open(outdir / 'settings.yaml', 'r') as f:\n", - " curved_settings = keyframed.serialization.from_yaml(f.read())\n", - "\n", - "curved_settings.to_dict(simplify=True)['parameters']\n", - "#curved_settings.plot()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "i-_u1Q0wRqMb", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Generate Animation Frames\n", - "\n", - "resume = True # @param {type:'boolean'}\n", - "archive_old_work = False # @param {type:'boolean'}\n", - "\n", - "# -1 = most recent frame\n", - "resume_from = -1 # @param {type:'number'}\n", - "\n", - "# @markdown optional debugging plots\n", - "plot_prompt_weights = True # @param {type:'boolean'}\n", - "plot_h = False # @param {type:'boolean'}\n", - "plot_g = False # @param {type:'boolean'}\n", - "plot_sigma = False # @param {type:'boolean'}\n", - "plot_gamma = False # @param {type:'boolean'}\n", - "plot_alpha = False # @param {type:'boolean'}\n", - "plot_tau = False # @param {type:'boolean'}\n", - "\n", - "################\n", - "\n", - "_seed = seed\n", - "if seed < 0: \n", - " _seed = random.randrange(0, 4294967295)\n", - "print(f\"using seed: {_seed}\")\n", - "torch.manual_seed(_seed)\n", - "\n", - "stuff_to_plot = []\n", - "if plot_prompt_weights:\n", - " stuff_to_plot.append('prompts')\n", - "if plot_h:\n", - " stuff_to_plot.append('h')\n", - "if plot_g:\n", - " stuff_to_plot.append('g')\n", - "if plot_sigma:\n", - " stuff_to_plot.append('sigma')\n", - "if plot_gamma:\n", - " stuff_to_plot.append('gamma')\n", - "if plot_alpha:\n", - " stuff_to_plot.append('alpha')\n", - "if plot_tau:\n", - " stuff_to_plot.append('tau')\n", - "\n", - "if not resume:\n", - " if 
archive_old_work:\n", - " archive_dir = outdir.parent / 'archive' / str(int(time.time()))\n", - " archive_dir.mkdir(parents=True, exist_ok=True)\n", - " print(f\"Archiving contents of /frames, moving to: {archive_dir}\")\n", - " else:\n", - " print(\"Old contents of /frames being deleted. This can be prevented in the future by setting either 'resume' or 'archive_old_work' to True.\")\n", - " for p in outdir.glob(f'*'):\n", - " if archive_old_work:\n", - " target = archive_dir / p.name\n", - " p.rename(target)\n", - " else:\n", - " p.unlink()\n", - " for p in Path('debug_frames').glob(f'*'):\n", - " p.unlink()\n", - "\n", - "sample_mcmc_klmc2(\n", - " sd_model=sd_model,\n", - " init_image=init_image,\n", - " height=height,\n", - " width=width,\n", - " n=n,\n", - " hvp_method=hvp_method,\n", - " prompts=prompts,\n", - " settings=curved_settings,\n", - " resume=resume,\n", - " resume_from=resume_from,\n", - " img_init_steps=img_init_steps,\n", - " stuff_to_plot=stuff_to_plot,\n", - " checkpoint_every=checkpoint_every,\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "DjwY7XrooLX_" - }, - "outputs": [], - "source": [ - "#@title Make the video\n", - "\n", - "if 'width' not in locals():\n", - " width = height = 512\n", - "\n", - "\n", - "# @markdown If your video is larger than a few MB, attempting to embed it will probably crash\n", - "# @markdown the session. 
If this happens, view the generated video after downloading it first.\n", - "embed_video = True # @param {type:'boolean'}\n", - "download_video = False # @param {type:'boolean'}\n", - "\n", - "upscale_video = False # @param {type:'boolean'}\n", - "\n", - "\n", - "outdir_str = str(outdir)\n", - "\n", - "fps = 14 # @param {type:\"integer\"}\n", - "out_fname = \"out.mp4\" # @param {type: \"string\"}\n", - "\n", - "out_fullpath = str( outdir / out_fname )\n", - "print(f\"Video will be saved to: {out_fullpath}\")\n", - "\n", - "compile_video_cmd = f\"ffmpeg -y -r {fps} -i 'out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p \" # {out_fname}\"\n", - "if upscale_video:\n", - " compile_video_cmd += f\"-vf scale={2*width}x{2*height}:flags=lanczos \"\n", - "compile_video_cmd += f\"{out_fname}\"\n", - "\n", - "print('\\nMaking the video...\\n')\n", - "!cd {outdir_str}; {compile_video_cmd}\n", - "\n", - "\n", - "debug=True\n", - "if debug:\n", - " #outdir_str = \"debug_frames\"\n", - " print(\"\\nMaking debug video...\")\n", - " #!cd debug_frames; ffmpeg -y -r {fps} -i 'prompts_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", - " !cd {debug_dir}; ffmpeg -y -r {fps} -i 'debug_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", - "\n", - "if embed_video:\n", - " print('\\nThe video:')\n", - " show_video(out_fullpath)\n", - " if debug:\n", - " show_video(debug_dir / \"debug_out.mp4\")\n", - "\n", - "if download_video and probably_using_colab:\n", - " from google.colab import files\n", - " files.download(out_fullpath)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rK_GlP_7WJiu", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Licensed under the MIT License { display-mode: \"form\" }\n", - "\n", - "# Copyright (c) 2022 Katherine Crowson \n", - "# Copyright (c) 2023 David Marx \n", - "# Copyright (c) 2022 deforum and contributors\n", - "\n", - "# Permission is hereby granted, free of 
charge, to any person obtaining a copy\n", - "# of this software and associated documentation files (the \"Software\"), to deal\n", - "# in the Software without restriction, including without limitation the rights\n", - "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", - "# copies of the Software, and to permit persons to whom the Software is\n", - "# furnished to do so, subject to the following conditions:\n", - "\n", - "# The above copyright notice and this permission notice shall be included in\n", - "# all copies or substantial portions of the Software.\n", - "\n", - "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", - "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", - "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", - "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", - "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", - "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", - "# THE SOFTWARE." 
- ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "private_outputs": true, - "provenance": [] - }, - "gpuClass": "premium", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "ff1624fd81a21ea709585fb1fdce5419f857f6a9e76cb1632f1b8b574978f9ee" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From e3fa140ee7c4557a97262cd6eb2f40f692d5284d Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 2 Mar 2023 10:46:34 -0800 Subject: [PATCH 08/14] added prompt persistence --- Stable_Diffusion_KLMC2_Animation.ipynb | 3273 ++++++++++++------------ 1 file changed, 1651 insertions(+), 1622 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 8bec79d..a7c430f 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1,1624 +1,1653 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "15BNHICpOOXg" - }, - "source": [ - "# Stable Diffusion KLMC2 Animation\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "\n", - "Notebook by [Katherine Crowson](https://twitter.com/RiversHaveWings), modified by [David Marx](https://twitter.com/DigThatData).\n", - "\n", - "Sponsored by [StabilityAI](https://twitter.com/stabilityai)\n", - "\n", - "Generate animations with [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) 1.4, using the [KLMC2 discretization of underdamped Langevin dynamics](https://arxiv.org/abs/1807.09382). The notebook is largely inspired by [Ajay Jain](https://twitter.com/ajayj_) and [Ben Poole](https://twitter.com/poolio)'s paper [Journey to the BAOAB-limit](https://www.ajayjain.net/journey)—thank you so much for it!\n", - "\n", - "---\n", - "\n", - "## Modifications Provenance\n", - "\n", - "Original notebook URL - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m8ovBpO2QilE2o4O-p2PONSwqGn4_x2G)\n", - "\n", - "Features and QOL Modifications by [David Marx](https://twitter.com/DigThatData) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmarx/notebooks/blob/main/Stable_Diffusion_KLMC2_Animation.ipynb)\n", - "\n", - "Last updated date (effectively, notebook version): 2022-02-20\n", - "\n", - "* Keyframed prompts and settings\n", - "* Multiprompt conditioning w independent prompt schedules\n", - "* Set seed for deterministic output\n", - "* Mount Google Drive\n", - "* Faster Setup\n", - "* Init image\n", - "* Alt-checkpoint loading consistent w/deforum\n", - "* Set output filename\n", - "* Fancy GPU info\n", - "* Video embed optional\n", - "* ~~Cheaper default runtime~~ torn about this\n", - "* Local setup\n", - "* New VAE option\n", - "* Smooth interpolation for settings curves\n", - "* Settings curves specified via simple DSL\n", - "* Exposed `refinement_steps` parameter\n", - "* Custom output resolution\n", - "* Optional video upscale\n", - "* Optional resume, user can specify 
resumption frame (auto-checkpoints every 10 frames)\n", - "* Optional archival\n", - "* Assorted refactoring\n", - "* Debugging plots and animations\n", - "\n", - "## Local Setup\n", - "\n", - "Download the repo containing this notebook and supplementary setup files.\n", - "\n", - "```\n", - "git clone https://github.com/dmarx/notebooks\n", - "cd notebooks\n", - "```\n", - "\n", - "Strongly recommend setting up and activating a virtual environment first. Here's one option that is built into python, windows users in particular might want to consider using anaconda as an alternative.\n", - "\n", - "```bash\n", - "python3 -m venv _venv\n", - "source _venv/bin/activate\n", - "pip install jupyter\n", - "```\n", - "\n", - "With this venv created, in the future you only need to run `source _venv/bin/activate` to activate it.\n", - "\n", - "You can now start a local jupyter instance from the terminal in which the virtual environment is activated by running the `jupyter` command, or alternatively select the new virtualenv as the python environment in your IDE of choice. When you run the notebook's setup cells, it should detect that local setup needs to be performed and modify its setup procedure appropriately.\n", - "\n", - "A common source of errors is user confusion between the python environment running the notebook and an intended virtual environment into which setup has already been performed. 
To validate that you are using the python environment you think you are, run the command `which python` (this locates the executable associated with the `python` command) both inside the notebook and in a terminal in which your venv is activated: the results should be identical.\n", - "\n", - "## Contact\n", - "\n", - "Report bugs or feature ideas here: https://github.com/dmarx/notebooks/issues" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Ty3IOeXbLzvc", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Check GPU\n", - "#!nvidia-smi\n", - "\n", - "import pandas as pd\n", - "import subprocess\n", - "\n", - "def gpu_info():\n", - " outv = subprocess.run([\n", - " 'nvidia-smi',\n", - " # these lines concatenate into a single query string\n", - " '--query-gpu='\n", - " 'timestamp,'\n", - " 'name,'\n", - " 'utilization.gpu,'\n", - " 'utilization.memory,'\n", - " 'memory.used,'\n", - " 'memory.free,'\n", - " ,\n", - " '--format=csv'\n", - " ],\n", - " stdout=subprocess.PIPE).stdout.decode('utf-8')\n", - "\n", - " header, rec = outv.split('\\n')[:-1]\n", - " return pd.DataFrame({' '.join(k.strip().split('.')).capitalize():v for k,v in zip(header.split(','), rec.split(','))}, index=[0]).T\n", - "\n", - "gpu_info()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "kelHR9VM1-hg", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Install Dependencies\n", - "\n", - "# @markdown Your runtime will automatically restart after running this cell.\n", - "# @markdown You should only need to run this cell when setting up a new runtime. 
After future runtime restarts,\n", - "# @markdown you should be able to skip this cell.\n", - "\n", - "import warnings\n", - "\n", - "probably_using_colab = False\n", - "try:\n", - " import google\n", - " probably_using_colab = True\n", - "except ImportError:\n", - " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", - "\n", - "# @markdown Not recommended for colab users. This notebook is currently configured to only make this\n", - "# @markdown option available for local install.\n", - "use_xformers = False\n", - "\n", - "try:\n", - " import keyframed\n", - "except ImportError:\n", - " if probably_using_colab:\n", - " !pip install ftfy einops braceexpand requests transformers clip open_clip_torch omegaconf pytorch-lightning kornia k-diffusion ninja omegaconf\n", - " !pip install -U git+https://github.com/huggingface/huggingface_hub\n", - " !pip install napm keyframed\n", - " else:\n", - " !pip install -r klmc2/requirements.txt\n", - " if use_xformers:\n", - " !pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n", - "\n", - " exit() # restarts the runtime" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "fJZtXShcPXx5", - "tags": [] - }, - "outputs": [], - "source": [ - "# @markdown # Setup Workspace { display-mode: \"form\" }\n", - "\n", - "###################\n", - "# Setup Workspace #\n", - "###################\n", - "\n", - "import os\n", - "from pathlib import Path\n", - "import warnings\n", - "\n", - "probably_using_colab = False\n", - "try:\n", - " import google\n", - " if Path('/content').exists():\n", - " probably_using_colab = True\n", - " print(\"looks like we're in colab\")\n", - " else:\n", - " print(\"looks like we're not in colab\")\n", - "except ImportError:\n", - " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", - "\n", - "\n", - "mount_gdrive = True # 
@param {type:'boolean'}\n", - "\n", - "# defaults\n", - "outdir = Path('./frames')\n", - "if not os.environ.get('XDG_CACHE_HOME'):\n", - " os.environ['XDG_CACHE_HOME'] = str(Path('~/.cache').expanduser())\n", - "\n", - "if mount_gdrive and probably_using_colab:\n", - " from google.colab import drive\n", - " drive.mount('/content/drive')\n", - " Path('/content/drive/MyDrive/AI/models/.cache/').mkdir(parents=True, exist_ok=True) \n", - " os.environ['XDG_CACHE_HOME']='/content/drive/MyDrive/AI/models/.cache'\n", - " outdir = Path('/content/drive/MyDrive/AI/klmc2/frames/')\n", - "\n", - "# make sure the paths we need exist\n", - "outdir.mkdir(parents=True, exist_ok=True)\n", - "debug_dir = outdir.parent / 'debug_frames'\n", - "debug_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "\n", - "os.environ['NAPM_PATH'] = str( Path(os.environ['XDG_CACHE_HOME']) / 'napm' )\n", - "Path(os.environ['NAPM_PATH']).mkdir(parents=True, exist_ok=True)\n", - "\n", - "\n", - "import napm\n", - "\n", - "url = 'https://github.com/Stability-AI/stablediffusion'\n", - "napm.pseudoinstall_git_repo(url, add_install_dir_to_path=True)\n", - "\n", - "\n", - "##### Moved from model loading cell\n", - "\n", - "if probably_using_colab:\n", - " models_path = \"/content/models\" #@param {type:\"string\"}\n", - "else:\n", - " models_path = os.environ['XDG_CACHE_HOME']\n", - "\n", - "if mount_gdrive and probably_using_colab:\n", - " models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n", - " models_path = models_path_gdrive\n", - "\n", - "if not Path(models_path).exists():\n", - " Path(models_path).mkdir(parents=True, exist_ok=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "y2jXKIf2ZkT8", - "tags": [] - }, - "outputs": [], - "source": [ - "# @markdown # Imports and Definitions { display-mode: \"form\" }\n", - "\n", - "###########\n", - "# imports #\n", - "###########\n", - "\n", - "# importing napm 
puts the stable diffusion repo on the PATH, which is where `ldm` imports from\n", - "import napm\n", - "from ldm.util import instantiate_from_config\n", - "\n", - "from base64 import b64encode\n", - "from collections import defaultdict\n", - "from concurrent import futures\n", - "import math\n", - "from pathlib import Path\n", - "import random\n", - "import re\n", - "import requests\n", - "from requests.exceptions import HTTPError\n", - "import sys\n", - "import time\n", - "from urllib.parse import urlparse\n", - "import warnings\n", - "\n", - "import functorch\n", - "import huggingface_hub\n", - "from IPython.display import display, Video, HTML\n", - "import k_diffusion as K\n", - "from keyframed import Curve, ParameterGroup, SmoothCurve\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np \n", - "from omegaconf import OmegaConf\n", - "import PIL\n", - "from PIL import Image\n", - "import torch\n", - "from torch import nn\n", - "from tqdm.auto import tqdm, trange\n", - "\n", - "from loguru import logger\n", - "import time\n", - "from natsort import natsorted\n", - "\n", - "\n", - "#########################\n", - "# Define useful globals #\n", - "#########################\n", - "\n", - "cpu = torch.device(\"cpu\")\n", - "device = torch.device(\"cuda\")\n", - "\n", - "\n", - "##############################\n", - "# Define necessary functions #\n", - "##############################\n", - " \n", - "import matplotlib.image\n", - "import numpy as np\n", - " \n", - "def get_latest_frame(i=None, latest_frame_fpath=None):\n", - " latest_frame = None\n", - " if latest_frame_fpath is not None:\n", - " latest_frame = latest_frame_fpath\n", - " \n", - " if (latest_frame is None) and (i is None):\n", - " frames = Path('frames').glob(\"*.png\")\n", - " #logger.debug(f\"frames: {len(frames)}\")\n", - " latest_frame = natsort.sort(frames)[-1]\n", - " i = re.findall('out_([0-9]+).png', latest_frame.name)\n", - " else:\n", - " latest_frame = Path('frames') / 
f\"out_{i:05}.png\"\n", - " logger.debug(f'-latest_frame: {latest_frame}')\n", - " #return Image.open(latest_frame)\n", - " img = matplotlib.image.imread(latest_frame)\n", - " return np.flip(img, axis=0) # up/down\n", - "\n", - "def plot_prompts(prompts=None, n=1000, settings=None, **kargs):\n", - " if prompts is not None:\n", - " for prompt in prompts:\n", - " prompt.weight.plot(n=n, **kargs)\n", - "\n", - "def plot_param(param, settings=None, prompts=None, n=1000, **kargs):\n", - " settings.parameters[param].plot(n=n, **kargs)\n", - " \n", - "# move imports up\n", - "import base64\n", - "from io import BytesIO\n", - "from functools import partial\n", - " \n", - "@logger.catch\n", - "def write_debug_frame_at_(\n", - " i=None,\n", - " n=300, \n", - " prompts=None, \n", - " stuff_to_plot=['prompts'], \n", - " latest_frame_fpath=None,\n", - " pil_image=None,\n", - " settings=None,\n", - "):\n", - " plotting_funcs = {\n", - " 'prompts': plot_prompts,\n", - " 'g': partial(plot_param, param='g'),\n", - " 'h': partial(plot_param, param='h'),\n", - " 'sigma': partial(plot_param, param='sigma'),\n", - " 'gamma': partial(plot_param, param='gamma'),\n", - " 'alpha': partial(plot_param, param='alpha'),\n", - " 'tau': partial(plot_param, param='tau'),\n", - " }\n", - " \n", - " # i feel like this line of code justifies the silly variable name\n", - " if not stuff_to_plot:\n", - " return\n", - " \n", - " #stuff_to_plot = []\n", - " \n", - " test_im = pil_image\n", - " if pil_image is None:\n", - " test_im = get_latest_frame(i, latest_frame_fpath)\n", - "\n", - " fig = plt.figure()\n", - " #axsRight = fig.subplots(3, 1, sharex=True)\n", - " #ax = axsRight[0]\n", - " ax_objs = fig.subplots(len(stuff_to_plot), 1, sharex=True)\n", - " \n", - " #width, height = test_im.size\n", - " height, width = test_im.size\n", - " fig.set_size_inches(height/fig.dpi, width/fig.dpi )\n", - " \n", - " buffer = BytesIO()\n", - " for j, category in enumerate(stuff_to_plot):\n", - " ax = ax_objs\n", - 
" if len(stuff_to_plot) > 1:\n", - " ax = ax_objs[j]\n", - " plt.sca(ax)\n", - " plt.tight_layout()\n", - " plt.axis('off')\n", - " \n", - " plotting_funcs[category](prompts=prompts, settings=settings, n=n, zorder=1)\n", - " plt.axvline(x=i)\n", - " \n", - " \n", - "\n", - " #plt.margins(0)\n", - " fig.savefig(buffer, transparent=True) \n", - " plt.close()\n", - "\n", - " buffer.seek(0)\n", - " plot_pil = Image.open(buffer)\n", - " #buffer.close() # throws error here\n", - "\n", - " #debug_im_path = Path('debug_frames') / f\"{category}_out_{i:05}.png\"\n", - " #debug_im_path = Path('debug_frames') / f\"debug_out_{i:05}.png\"\n", - " debug_im_path = debug_dir / f\"debug_out_{i:05}.png\"\n", - " test_im = test_im.convert('RGBA')\n", - " test_im.paste(plot_pil, (0,0), plot_pil)\n", - " test_im.save(debug_im_path)\n", - " #display(test_im) # maybe?\n", - " buffer.close() # I guess?\n", - " \n", - " return test_im, plot_pil\n", - "\n", - "##############################\n", - "\n", - "class Prompt:\n", - " def __init__(\n", - " self,\n", - " text,\n", - " weight_schedule,\n", - " ):\n", - " c = sd_model.get_learned_conditioning([text])\n", - " self.text=text\n", - " self.encoded=c\n", - " self.weight = SmoothCurve(weight_schedule)\n", - "\n", - "\n", - "def handle_chigozienri_curve_format(value_string):\n", - " if value_string.startswith('(') and value_string.endswith(')'):\n", - " value_string = value_string[1:-1]\n", - " return value_string\n", - "\n", - "def parse_curve_string(txt, f=float):\n", - " schedule = {}\n", - " for tokens in txt.split(','):\n", - " k,v = tokens.split(':')\n", - " v = handle_chigozienri_curve_format(v)\n", - " schedule[int(k)] = f(v)\n", - " return schedule\n", - "\n", - "def parse_curvable_string(param, is_int=False):\n", - " if isinstance(param, dict):\n", - " return param\n", - " f = float\n", - " if is_int:\n", - " f = int\n", - " try:\n", - " return f(param)\n", - " except ValueError:\n", - " return parse_curve_string(txt=param, f=f)\n", 
- "\n", - "##################\n", - "\n", - "def show_video(video_path, video_width=512):\n", - " return display(Video(video_path, width=video_width))\n", - "\n", - "if probably_using_colab:\n", - " def show_video(video_path, video_width=512):\n", - " video_file = open(video_path, \"r+b\").read()\n", - " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", - " return display(HTML(f\"\"\"\"\"\"))\n", - "\n", - "##################\n", - "\n", - "class NormalizingCFGDenoiser(nn.Module):\n", - " def __init__(self, model, g):\n", - " super().__init__()\n", - " self.inner_model = model\n", - " self.g = g\n", - " self.eps_norms = defaultdict(lambda: (0, 0))\n", - "\n", - " def mean_sq(self, x):\n", - " return x.pow(2).flatten(1).mean(1)\n", - "\n", - " @torch.no_grad()\n", - " def update_eps_norm(self, eps, sigma):\n", - " sigma = sigma[0].item()\n", - " eps_norm = self.mean_sq(eps).mean()\n", - " eps_norm_avg, count = self.eps_norms[sigma]\n", - " eps_norm_avg = eps_norm_avg * count / (count + 1) + eps_norm / (count + 1)\n", - " self.eps_norms[sigma] = (eps_norm_avg, count + 1)\n", - " return eps_norm_avg\n", - "\n", - " def forward(self, x, sigma, uncond, cond, g):\n", - " x_in = torch.cat([x] * 2)\n", - " sigma_in = torch.cat([sigma] * 2)\n", - " cond_in = torch.cat([uncond, cond])\n", - "\n", - " denoised = self.inner_model(x_in, sigma_in, cond=cond_in)\n", - " eps = K.sampling.to_d(x_in, sigma_in, denoised)\n", - " eps_uc, eps_c = eps.chunk(2)\n", - " eps_norm = self.update_eps_norm(eps, sigma).sqrt()\n", - " c = eps_c - eps_uc\n", - " cond_scale = g * eps_norm / self.mean_sq(c).sqrt()\n", - " eps_final = eps_uc + c * K.utils.append_dims(cond_scale, x.ndim)\n", - " return x - eps_final * K.utils.append_dims(sigma, eps.ndim)\n", - "\n", - "#########################\n", - "\n", - "def write_klmc2_state(**state):\n", - " st = time.time()\n", - " obj = {}\n", - " for k,v in state.items():\n", - " try:\n", - " v = v.clone().detach().cpu()\n", - " 
def read_klmc2_state(root=outdir, latest_frame=-1):
    """Load the most recent KLMC2 checkpoint from *root*.

    Parameters
    ----------
    root : path-like
        Directory containing ``*.ckpt`` files (defaults to the global outdir).
    latest_frame : int
        If negative, load the newest checkpoint. Otherwise load the newest
        checkpoint whose frame index is <= latest_frame.

    Returns
    -------
    dict or None
        The checkpoint state dict (loaded onto CUDA), or None when no
        suitable checkpoint exists.
    """
    checkpoints = [str(p) for p in Path(root).glob("*.ckpt")]
    if not checkpoints:
        return None
    # natural sort so e.g. frame 2 sorts before frame 10
    checkpoints = natsorted(checkpoints)
    if latest_frame < 0:
        ckpt_fpath = checkpoints[-1]
    else:
        # scan forward for the last checkpoint at or before latest_frame;
        # previously ckpt_fpath was unbound (UnboundLocalError) when the
        # very first checkpoint was already past latest_frame
        ckpt_fpath = None
        for fname in checkpoints:
            frame_id = re.findall(r'([0-9]+).ckpt', fname)[0]
            if int(frame_id) <= latest_frame:
                ckpt_fpath = fname
            else:
                break
        if ckpt_fpath is None:
            return None
        logger.debug(ckpt_fpath)
    with open(ckpt_fpath, 'rb') as f:
        state = torch.load(f=f, map_location='cuda')
    return state

def load_init_image(init_image, height, width):
    """Load an RGB init image, resize it, and rescale to [-1, 1] on CUDA.

    Raises FileNotFoundError when the path does not exist. Returns a
    float16-derived NCHW tensor ready for the first-stage encoder.
    """
    if not Path(init_image).exists():
        raise FileNotFoundError(f"Unable to locate init image from path: {init_image}")

    from PIL import Image
    import numpy as np

    init_im_pil = Image.open(init_image)

    # PIL's resize takes (width, height); the previous [height, width]
    # argument transposed the image for non-square outputs
    x_pil = init_im_pil.resize((width, height))
    x_np = np.array(x_pil.convert('RGB')).astype(np.float16) / 255.0
    x = x_np[None].transpose(0, 3, 1, 2)  # NHWC -> NCHW
    x = 2. * x - 1.  # [0, 1] -> [-1, 1]
    x = torch.from_numpy(x).to('cuda')
    return x
@torch.no_grad()
def sample_mcmc_klmc2(
    sd_model,
    init_image,
    height: int,
    width: int,
    n: int,
    hvp_method: str = 'reverse',
    prompts: list = None,
    settings: ParameterGroup = None,
    resume: bool = False,
    resume_from: int = -1,
    img_init_steps: int = None,
    stuff_to_plot: list = None,
    checkpoint_every: int = 10,
):
    """Run the KLMC2 MCMC sampling loop, writing frames and checkpoints.

    Parameters
    ----------
    sd_model : latent-diffusion model (encode/decode + conditioning).
    init_image : path to an init image, or falsy to start from noise.
    height, width : output dimensions in pixels.
    n : total number of frames to sample.
    hvp_method : 'forward-functorch' | 'reverse' | 'zero' | 'fake'.
    prompts : list of Prompt objects with keyframed weights.
    settings : keyframed ParameterGroup of per-frame hyperparameters.
    resume : resume from the latest on-disk checkpoint when True.
    resume_from : frame index to resume from (-1 = newest checkpoint).
    img_init_steps : steps used when generating a random init image.
    stuff_to_plot : parameter names overlaid on debug frames.
    checkpoint_every : write a state checkpoint every this many frames.
    """
    if stuff_to_plot is None:
        stuff_to_plot = ['prompts', 'h']

    torch.cuda.empty_cache()

    wrappers = {'eps': K.external.CompVisDenoiser, 'v': K.external.CompVisVDenoiser}
    g = settings[0]['g']

    model_wrap = wrappers[sd_model.parameterization](sd_model)
    model_wrap_cfg = NormalizingCFGDenoiser(model_wrap, g)
    sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
    model = model_wrap_cfg

    uc = sd_model.get_learned_conditioning([''])
    extra_args = {
        'uncond': uc,
        # 'cond' is filled in per-prompt inside multicond_hvp
        'g': settings[0]['g'],
    }

    sigma = settings[0]['sigma']

    with torch.cuda.amp.autocast(), futures.ThreadPoolExecutor() as ex:
        def callback(info):
            # decode the refined latent and hand the frame write to a worker
            i = info['i']
            rgb = sd_model.decode_first_stage(info['denoised'])
            ex.submit(save_image_fn, image=rgb, name=(outdir / f'out_{i:05}.png'),
                      i=i, n=n, prompts=prompts, settings=settings,
                      stuff_to_plot=stuff_to_plot)

        # Initialize the chain
        print('Initializing the chain...')

        # to do: if resuming, generating this init image is unnecessary
        x = None
        if init_image:
            print("loading init image")
            x = load_init_image(init_image, height, width)
            # convert RGB to latent
            x = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x))
            print("init image loaded.")

        print('Actually doing the sampling...')

        i_resume = 0
        v = None
        if resume:
            state = read_klmc2_state(latest_frame=resume_from)
            if state:
                x, v, i_resume = state['x'], state['v'], state['i']
                # Restore the hyperparameters that were active when the
                # checkpoint was written. The previous code indexed
                # `settings[i]` here, but `i` is only bound later by the
                # sampling loop, so every resume raised UnboundLocalError;
                # the checkpointed frame index is `i_resume`.
                settings_i = state.get('settings_i')  # absent in old checkpoints
                if settings_i is not None:
                    # NOTE(review): assumes settings[i_resume] yields a live,
                    # mutable view into the ParameterGroup -- confirm against
                    # keyframed's semantics.
                    for key in ('h', 'gamma', 'alpha', 'tau', 'g', 'sigma', 'steps'):
                        settings[i_resume][key] = settings_i[key]

        # to do: use multicond for init image
        # we want this test after resumption if resuming
        if x is None:
            print("No init image provided, generating a random init image")
            extra_args['cond'] = prompts[0].encoded
            h = height // 8
            w = width // 8
            x = torch.randn([1, 4, h, w], device=device) * sigma_max
            sigmas_pre = K.sampling.get_sigmas_karras(img_init_steps, sigma, sigma_max, device=x.device)[:-1]
            x = K.sampling.sample_dpmpp_sde(model_wrap_cfg, x, sigmas_pre, extra_args=extra_args)

        # if not resuming, randomly initialize momentum;
        # this must happen *after* x exists so randn_like has a template
        if v is None:
            v = torch.randn_like(x) * sigma

        # main sampling loop
        for i in trange(n):
            # fast-forward loop to resumption index
            if resume and i < i_resume:
                continue

            h = settings[i]['h']
            gamma = settings[i]['gamma']
            alpha = settings[i]['alpha']
            tau = settings[i]['tau']
            g = settings[i]['g']
            sigma = settings[i]['sigma']
            steps = settings[i]['steps']

            h = torch.tensor(h, device=x.device)
            gamma = torch.tensor(gamma, device=x.device)
            alpha = torch.tensor(alpha, device=x.device)
            tau = torch.tensor(tau, device=x.device)
            sigma = torch.tensor(sigma, device=x.device)
            steps = int(steps)

            sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma.item(), device=x.device)[:-1]

            x, v, grad = klmc2_step(
                model,
                prompts,
                x,
                v,
                h,
                gamma,
                alpha,
                tau,
                g,
                sigma,
                sigmas,
                steps,
                hvp_method,
                i,
                callback,
                extra_args,
            )

            # periodically persist chain state (and the live hyperparameters)
            # so the run can be resumed
            save_checkpoint = (i % checkpoint_every) == 0
            if save_checkpoint:
                ex.submit(write_klmc2_state, v=v, x=x, i=i, settings_i=settings[i])
            logger.debug(settings[i])
def hvp_fn_reverse(model, x, sigma, v, alpha, extra_args):
    """Exact gradient plus reverse-mode (VJP) Hessian-vector products.

    Backpropagates through the regularized score once per row of ``v``,
    retaining the graph for all but the final row.
    """
    def score(inp, sig):
        batch_ones = inp.new_ones([inp.shape[0]])
        denoised = model(inp, sig * batch_ones, **extra_args)
        return (inp - denoised) + alpha * inp

    products = []
    with torch.enable_grad():
        x_leaf = x.clone().requires_grad_()
        grad = score(x_leaf, sigma)
        last = len(v) - 1
        for idx, direction in enumerate(v):
            keep_graph = idx < last
            products.append(
                torch.autograd.grad(grad, x_leaf, direction, retain_graph=keep_graph)[0]
            )
    return grad, torch.stack(products)

def hvp_fn_zero(model, x, sigma, v, alpha, extra_args):
    """First-order fallback: real gradient, zeros in place of the HVPs."""
    batch_ones = x.new_ones([x.shape[0]])
    denoised = model(x, sigma * batch_ones, **extra_args)
    grad = (x - denoised) + alpha * x
    return grad, torch.zeros_like(v)

def hvp_fn_fake(model, x, sigma, v, alpha, extra_args):
    """Cheap HVP approximation: treat the Hessian as (1 + alpha) * I."""
    batch_ones = x.new_ones([x.shape[0]])
    denoised = model(x, sigma * batch_ones, **extra_args)
    grad = (x - denoised) + alpha * x
    return grad, (1 + alpha) * v
def klmc2_step(
    model,
    prompts,
    x,
    v,
    h,
    gamma,
    alpha,
    tau,
    g,
    sigma,
    sigmas,
    steps,
    hvp_method,
    i,
    callback,
    extra_args,
    ):
    """Advance the KLMC2 (second-order kinetic Langevin) chain by one frame.

    Returns the updated position ``x``, momentum ``v``, and the gradient used
    for this step. ``callback``, when provided, receives the refined denoised
    latent for frame ``i`` so it can be decoded and written out.

    NOTE(review): assumes h/gamma/alpha/tau/sigma arrive as 0-d tensors on
    x's device (sample_mcmc_klmc2 converts them before calling) -- confirm
    if calling from elsewhere.
    """

    #s_in = x.new_ones([x.shape[0]])

    # Model helper functions

    # dispatch table selecting the Hessian-vector-product strategy
    hvp_fns = {'forward-functorch': hvp_fn_forward_functorch,
               'reverse': hvp_fn_reverse,
               'zero': hvp_fn_zero,
               'fake': hvp_fn_fake}

    hvp_fn = hvp_fns[hvp_method]

    # KLMC2 helper functions: integrator weight functions of the friction
    # coefficient gamma and elapsed time t
    def psi_0(gamma, t):
        return torch.exp(-gamma * t)

    def psi_1(gamma, t):
        return -torch.expm1(-gamma * t) / gamma

    def psi_2(gamma, t):
        return (torch.expm1(-gamma * t) + gamma * t) / gamma ** 2

    def phi_2(gamma, t_):
        # evaluated in float64 then cast back -- presumably to control the
        # expm1/cancellation error at small gamma*t
        t = t_.double()
        out = (torch.exp(-gamma * t) * (torch.expm1(gamma * t) - gamma * t)) / gamma ** 2
        return out.to(t_)

    def phi_3(gamma, t_):
        # float64 for the same precision reasons as phi_2
        t = t_.double()
        out = (torch.exp(-gamma * t) * (2 + gamma * t + torch.exp(gamma * t) * (gamma * t - 2))) / gamma ** 3
        return out.to(t_)


    # Compute model outputs and sample noise.
    # Build the 4x4 covariance of the correlated noise terms by numerically
    # (trapezoidal rule) integrating products of the weight functions over
    # [0, h].
    # NOTE: the comprehension variables i, j shadow the frame-index
    # parameter `i` only inside the comprehension (Python 3 comprehension
    # scoping); the later callback still sees the frame index.
    x_trapz = torch.linspace(0, h, 1001, device=x.device)
    y_trapz = [fun(gamma, x_trapz) for fun in (psi_0, psi_1, phi_2, phi_3)]
    noise_cov = torch.tensor([[torch.trapz(y_trapz[i] * y_trapz[j], x=x_trapz) for j in range(4)] for i in range(4)], device=x.device)
    noise_v, noise_x, noise_v2, noise_x2 = torch.distributions.MultivariateNormal(x.new_zeros([4]), noise_cov).sample(x.shape).unbind(-1)

    extra_args['g']=g

    # compute derivatives, multicond wrapper loops over prompts and averages derivatives
    # NOTE(review): multicond_hvp returns five values with h2_noise_v2
    # duplicated in positions 3 and 4; binding the same name twice in the
    # unpack is legal and harmless here because both slots carry the same
    # tensor, but the function and this call site should be reduced to four
    # values together.
    grad, h2_v, h2_noise_v2, h2_noise_v2, h2_noise_x2 = multicond_hvp(
        model=model, 
        x=x, 
        sigma=sigma, 
        v=torch.stack([v, noise_v2, noise_x2]), # need a "dummy" v for init image generation
        alpha=alpha, 
        extra_args=extra_args, 
        prompts=prompts, 
        hvp_fn=hvp_fn,
        i=i,
    )

    # DPM-Solver++(2M) refinement steps: polish the current latent along a
    # short Karras sigma schedule before it is decoded as the output frame
    x_refine = x
    use_dpm = True
    old_denoised = None
    for j in range(len(sigmas) - 1):
        if j == 0:
            # reuse the gradient already computed above for the first step
            denoised = x_refine - grad
        else:
            s_in = x.new_ones([x.shape[0]])
            denoised = model(x_refine, sigmas[j] * s_in, **extra_args)
        dt_ode = sigmas[j + 1] - sigmas[j]
        if not use_dpm or old_denoised is None or sigmas[j + 1] == 0:
            # first-order (Euler) update
            eps = K.sampling.to_d(x_refine, sigmas[j], denoised)
            x_refine = x_refine + eps * dt_ode
        else:
            # second-order multistep update blending the previous denoised
            # estimate
            h_ode = sigmas[j].log() - sigmas[j + 1].log()
            h_last = sigmas[j - 1].log() - sigmas[j].log()
            fac = h_ode / (2 * h_last)
            denoised_d = (1 + fac) * denoised - fac * old_denoised
            eps = K.sampling.to_d(x_refine, sigmas[j], denoised_d)
            x_refine = x_refine + eps * dt_ode
        old_denoised = denoised
    if callback is not None:
        callback({'i': i, 'denoised': x_refine})

    # Update the chain: exponential-integrator update of momentum and
    # position with temperature-scaled, correlated noise
    noise_std = (2 * gamma * tau * sigma ** 2).sqrt()
    v_next = 0 + psi_0(gamma, h) * v - psi_1(gamma, h) * grad - phi_2(gamma, h) * h2_v + noise_std * (noise_v - h2_noise_v2)
    x_next = x + psi_1(gamma, h) * v - psi_2(gamma, h) * grad - phi_3(gamma, h) * h2_v + noise_std * (noise_x - h2_noise_x2)
    v, x = v_next, x_next

    return x, v, grad
- " 'requires_login': True,\n", - " },\n", - " \"sd-v1-3-full-ema.ckpt\": {\n", - " 'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-3.ckpt\": {\n", - " 'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-2-full-ema.ckpt\": {\n", - " 'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-2.ckpt\": {\n", - " 'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-1-full-ema.ckpt\": {\n", - " 'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',\n", - " 'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"sd-v1-1.ckpt\": {\n", - " 'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',\n", - " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',\n", - " 'requires_login': True,\n", - " },\n", - " \"robo-diffusion-v1.ckpt\": {\n", - " 'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',\n", - " 'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',\n", - " 'requires_login': False,\n", - " },\n", - " \"model-epoch05-float16.ckpt\": {\n", - " 
'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece',\n", - " 'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt',\n", - " 'requires_login': False,\n", - " },\n", - "}\n", - "\n", - "# config path\n", - "ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n", - "if os.path.exists(ckpt_config_path):\n", - " print(f\"{ckpt_config_path} exists\")\n", - "else:\n", - " #ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n", - " ckpt_config_path = \"./v1-inference.yaml\"\n", - " if not Path(ckpt_config_path).exists():\n", - " !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n", - " \n", - "print(f\"Using config: {ckpt_config_path}\")\n", - "\n", - "# checkpoint path or download\n", - "ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n", - "ckpt_valid = True\n", - "if os.path.exists(ckpt_path):\n", - " print(f\"{ckpt_path} exists\")\n", - "elif 'url' in model_map[model_checkpoint]:\n", - " url = model_map[model_checkpoint]['url']\n", - "\n", - " # CLI dialogue to authenticate download\n", - " if model_map[model_checkpoint]['requires_login']:\n", - " print(\"This model requires an authentication token\")\n", - " print(\"Please ensure you have accepted its terms of service before continuing.\")\n", - "\n", - " username = input(\"What is your huggingface username?:\")\n", - " token = input(\"What is your huggingface token?:\")\n", - "\n", - " _, path = url.split(\"https://\")\n", - "\n", - " url = f\"https://{username}:{token}@{path}\"\n", - "\n", - " # contact server for model\n", - " print(f\"Attempting to download {model_checkpoint}...this may take a while\")\n", - " ckpt_request = requests.get(url)\n", - " request_status = ckpt_request.status_code\n", - "\n", - " # inform 
# Optionally verify the downloaded checkpoint against its published sha256.
if check_sha256 and model_checkpoint != "custom" and ckpt_valid:
    import hashlib
    print("\n...checking sha256")
    with open(ckpt_path, "rb") as f:
        ckpt_bytes = f.read()  # renamed: previously shadowed builtins `bytes`/`hash`
    ckpt_hash = hashlib.sha256(ckpt_bytes).hexdigest()
    del ckpt_bytes  # free the multi-GB buffer immediately
    if model_map[model_checkpoint]["sha256"] == ckpt_hash:
        print("hash is correct\n")
    else:
        # fixed typo in user-facing message ("hash in not correct")
        print("hash is not correct\n")
        ckpt_valid = False

if ckpt_valid:
    print(f"Using ckpt: {ckpt_path}")

def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):
    """Instantiate the LDM model described by *config* and load *ckpt* weights.

    Parameters
    ----------
    config : OmegaConf config with a ``model`` node for instantiate_from_config.
    ckpt : path to a ``state_dict`` checkpoint file.
    verbose : print missing/unexpected state-dict keys when True.
    device : device the model is moved to after loading.
    half_precision : convert the model to fp16 before moving to *device*.

    Returns the model in eval mode. Loads with strict=False, so partial
    checkpoints are accepted.
    """
    # NOTE(review): map_location is hard-coded to "cuda" while a `device`
    # parameter also exists -- confirm whether loading should follow `device`.
    map_location = "cuda"
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location=map_location)
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if half_precision:
        model = model.half().to(device)
    else:
        model = model.to(device)
    model.eval()
    return model
try:\n", - " config['model']['params']['lossconfig']['target'] = \"torch.nn.Identity\"\n", - " print('Patched VAE config.')\n", - " except KeyError:\n", - " pass\n", - "\n", - " model = instantiate_from_config(config.model)\n", - " m, u = model.load_state_dict(sd, strict=False)\n", - " model = model.to(cpu).eval().requires_grad_(False)\n", - " return model\n", - "\n", - " vaemodel_yaml_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/models/first_stage_models/kl-f8/config.yaml\"\n", - " vaemodel_yaml_fname = 'config_vae_kl-f8.yaml'\n", - " vaemodel_yaml_fname_git = \"latent-diffusion/models/first_stage_models/kl-f8/config.yaml\"\n", - " if Path(vaemodel_yaml_fname_git).exists():\n", - " vae_model = load_model_from_config_kc(vaemodel_yaml_fname_git, vae_840k_model_path).half().to(device)\n", - " else:\n", - " if not Path(vaemodel_yaml_fname).exists():\n", - " !wget {vaemodel_yaml_url} -O {vaemodel_yaml_fname}\n", - " vae_model = load_model_from_config_kc(vaemodel_yaml_fname, vae_840k_model_path).half().to(device)\n", - "\n", - " del sd_model.first_stage_model\n", - " sd_model.first_stage_model = vae_model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "ZljSF1ePnBl4", - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Settings\n", - "\n", - "# @markdown The number of frames to sample:\n", - "n = 300 # @param {type:\"integer\"}\n", - "\n", - "# @markdown height and width must be multiples of 8 (e.g. 256, 512, 768, 1024)\n", - "height = 512 # @param {type:\"integer\"}\n", - "\n", - "width = 512 # @param {type:\"integer\"}\n", - "\n", - "\n", - "# @markdown If seed is negative, a random seed will be used\n", - "seed = -1 # @param {type:\"number\"}\n", - "\n", - "init_image = \"\" # @param {type:'string'}\n", - "\n", - "# @markdown ---\n", - "\n", - "# @markdown Settings below this line can be parameterized using keyframe syntax: `\"time:weight, time:weight, ...\". 
\n", - "# @markdown Over spans where values of weights change, intermediate values will be interpolated using an \"s\" shaped curve.\n", - "# @markdown If a value for keyframe 0 is not specified, it is presumed to be `0:0`.\n", - "\n", - "# @markdown The strength of the conditioning on the prompt:\n", - "g=\"0:0.1\" # @param {type:\"string\"}\n", - "\n", - "# @markdown The noise level to sample at\n", - "# @markdown Ramp up from a tiny sigma if using init image, e.g. `0:0.25, 100:2, ...`\n", - "# @markdown NB: Turning sigma *up* mid generation seems to work fine, but turning sigma *down* mid generation tends to \"deep fry\" the outputs\n", - "sigma = \"1.25\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Step size (range 0 to 1):\n", - "h = \"0:0.1, 30:0.1, 50:0.3, 70:0.1, 120:0.1, 140:.3, 160:.1, 210:.1, 230:.3, 250:.1\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Friction (2 is critically damped, lower -> smoother animation):\n", - "gamma = \"1.1\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Quadratic penalty (\"weight decay\") strength:\n", - "alpha = \"0.005\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Temperature (adjustment to the amount of noise added per step):\n", - "tau = \"1.0\" # @param {type:\"string\"}\n", - "\n", - "# @markdown Denoising refinement steps:\n", - "refinement_steps = \"6\" # @param {type:\"string\"}\n", - "\n", - "# @markdown If an init image is not provided, this is how many steps will be used when generating an initial state:\n", - "img_init_steps = 15 # @param {type:\"number\"}\n", - "\n", - "# @markdown The HVP method:\n", - "# @markdown
`forward-functorch` and `reverse` provide real second derivatives. Compatibility, speed, and memory usage vary by model and xformers configuration.\n", - "# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.\n", - "hvp_method = 'fake' # @param [\"forward-functorch\", \"reverse\", \"fake\", \"zero\"]\n", - "\n", - "checkpoint_every = 10 # @param {type:\"number\"}\n", - "\n", - "###########################\n", - "\n", - "assert (height % 8) == 0\n", - "assert (width % 8) == 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1pLTsdGBPXx6", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Prompts\n", - "\n", - "# [ \n", - "# [\"first prompt will be used to initialize the image\", {time:weight, time:weight...}], \n", - "# [\"more prompts if you want\", {...}], \n", - "# ...]\n", - "\n", - "# if a weight for time=0 isn't specified, the weight is assumed to be zero.\n", - "# if you want to \"fade in\" any prompts, it's best to have them start with a small but non-zero value, e.g. 
0.001\n", - "\n", - "prompt_params = [\n", - " # # FIRST PROMPT INITIALIZES IMAGE\n", - " #[\"sweetest puppy, golden retriever\", {0:.5, 30:0.5, 100:0.001}],\n", - " #[\"sweet old dog, golden retriever\", {0:0.001, 30:0.001, 100:0.5}],\n", - " #[\"happiest pupper, cutest dog evar, golden retriever, incredibly photogenic dog\", {0:1}],\n", - "\n", - " # # the 'flowers prompts' below go with a particular 'h' setting in the next cell\n", - " [\"incredibly beautiful orchids, a bouquet of orchids\", {0:1, 35:1, 50:0}],\n", - " [\"incredibly beautiful roses, a bouquet of roses\", {0:0.001, 35:0.001, 50:1, 120:1, 140:0}],\n", - " [\"incredibly beautiful carnations, a bouquet of carnations\", {0:0.001, 120:0.001, 140:1, 220:1, 240:0}],\n", - " [\"incredibly beautiful carnations, a bouquet of sunflowers\", {0:0.001, 220:0.001, 240:1}],\n", - " \n", - " # negative prompts\n", - " [\"watermark text\", {0:-0.1} ],\n", - " [\"jpeg artifacts\", {0:-0.1} ],\n", - " [\"artist's signature\", {0:-0.1} ],\n", - " [\"istockphoto, gettyimages, watermarked image\", {0:-0.1} ],\n", - "]\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "hZ0lh-WkdB19", - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Build prompt and settings objects\n", - "\n", - "# @markdown some advanced features demonstrated in commented-out code in this cell\n", - "\n", - "plot_prompt_weight_curves = True # @param {type: 'boolean'}\n", - "plot_settings_weight_curves = True # @param {type: 'boolean'}\n", - "\n", - "#################\n", - "\n", - "# Build Prompt objects\n", - "\n", - "prompts = [\n", - " Prompt(text, weight_schedule) \n", - " for (text, weight_schedule) in prompt_params\n", - "]\n", - "\n", - "# uncomment to loop the prompts\n", - "#for p in prompts:\n", - "# if len(p.weight.keyframes) > 1: # ignore negative prompts\n", - "# p.weight.loop=True \n", - "\n", - "# uncomment to loop prompts in \"bounce\" mode\n", - 
"#for p in prompts:\n", - "# if len(p.weight.keyframes) > 1:\n", - "# p.weight.bounce=True \n", - "\n", - "#################\n", - "\n", - "# Build Settings object\n", - "\n", - "g = parse_curvable_string(g)\n", - "sigma = parse_curvable_string(sigma)\n", - "h = parse_curvable_string(h)\n", - "gamma = parse_curvable_string(gamma)\n", - "alpha = parse_curvable_string(alpha)\n", - "tau = parse_curvable_string(tau)\n", - "steps = parse_curvable_string(refinement_steps)\n", - "\n", - "\n", - "curved_settings = ParameterGroup({\n", - " 'g':SmoothCurve(g),\n", - " 'sigma':SmoothCurve(sigma),\n", - " #'h':SmoothCurve(h),\n", - " \n", - " # more concise notation for flowers demo:\n", - " 'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", - " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3, 70:0.1, 90:0.1}, loop=True),\n", - "\n", - " 'gamma':SmoothCurve(gamma),\n", - " 'alpha':SmoothCurve(alpha),\n", - " 'tau':SmoothCurve(tau),\n", - " 'steps':SmoothCurve(steps),\n", - "})\n", - "\n", - "\n", - "if plot_prompt_weight_curves:\n", - " for prompt in prompts:\n", - " prompt.weight.plot(n=n)\n", - " plt.title(\"prompt weight schedules\")\n", - " plt.show()\n", - "\n", - "\n", - "if plot_settings_weight_curves:\n", - " for name, curve in curved_settings.parameters.items():\n", - " curve.plot(n=n)\n", - " plt.title(name)\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [], - "id": "tthag9k67Uey" - }, - "outputs": [], - "source": [ - "# @markdown running this cell saves the current settings to disk\n", - "\n", - "import keyframed.serialization\n", - "txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", - "\n", - "#print(txt)\n", - "\n", - "# sigma: 1.25\n", - "#\n", - "# becomes:\n", - "#\n", - "# sigma:\n", - "# curve:\n", - "# - - 0\n", - "# - 1.25\n", - "# - eased_lerp\n", - "#\n", - "# :\n", - "# curve:\n", - "# - - \n", - "# - \n", - "# - \n", - "# - \n", - "# - - \n", - "# - \n", - "# - - \n", - 
"# - \n", - "\n", - "with open(outdir / 'settings.yaml', 'w') as f:\n", - " f.write(txt)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [], - "cellView": "form", - "id": "srbY3kDa7Uey" - }, - "outputs": [], - "source": [ - "# load settings from disk\n", - "\n", - "# @markdown override current settings using the contents of `frames/settings.yaml`\n", - "\n", - "import keyframed.serialization\n", - "\n", - "load_settings_from_disk = True # @param {type:'boolean'}\n", - "\n", - "if load_settings_from_disk:\n", - " with open(outdir / 'settings.yaml', 'r') as f:\n", - " curved_settings = keyframed.serialization.from_yaml(f.read())\n", - "\n", - "curved_settings.to_dict(simplify=True)['parameters']\n", - "#curved_settings.plot()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "i-_u1Q0wRqMb", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Generate Animation Frames\n", - "\n", - "resume = True # @param {type:'boolean'}\n", - "archive_old_work = False # @param {type:'boolean'}\n", - "\n", - "# -1 = most recent frame\n", - "resume_from = -1 # @param {type:'number'}\n", - "\n", - "# @markdown optional debugging plots\n", - "plot_prompt_weights = True # @param {type:'boolean'}\n", - "plot_h = False # @param {type:'boolean'}\n", - "plot_g = False # @param {type:'boolean'}\n", - "plot_sigma = False # @param {type:'boolean'}\n", - "plot_gamma = False # @param {type:'boolean'}\n", - "plot_alpha = False # @param {type:'boolean'}\n", - "plot_tau = False # @param {type:'boolean'}\n", - "\n", - "################\n", - "\n", - "_seed = seed\n", - "if seed < 0: \n", - " _seed = random.randrange(0, 4294967295)\n", - "print(f\"using seed: {_seed}\")\n", - "torch.manual_seed(_seed)\n", - "\n", - "stuff_to_plot = []\n", - "if plot_prompt_weights:\n", - " stuff_to_plot.append('prompts')\n", - "if plot_h:\n", - " stuff_to_plot.append('h')\n", - "if plot_g:\n", - " 
stuff_to_plot.append('g')\n", - "if plot_sigma:\n", - " stuff_to_plot.append('sigma')\n", - "if plot_gamma:\n", - " stuff_to_plot.append('gamma')\n", - "if plot_alpha:\n", - " stuff_to_plot.append('alpha')\n", - "if plot_tau:\n", - " stuff_to_plot.append('tau')\n", - "\n", - "if not resume:\n", - " if archive_old_work:\n", - " archive_dir = outdir.parent / 'archive' / str(int(time.time()))\n", - " archive_dir.mkdir(parents=True, exist_ok=True)\n", - " print(f\"Archiving contents of /frames, moving to: {archive_dir}\")\n", - " else:\n", - " print(\"Old contents of /frames being deleted. This can be prevented in the future by setting either 'resume' or 'archive_old_work' to True.\")\n", - " for p in outdir.glob(f'*'):\n", - " if archive_old_work:\n", - " target = archive_dir / p.name\n", - " p.rename(target)\n", - " else:\n", - " p.unlink()\n", - " for p in Path('debug_frames').glob(f'*'):\n", - " p.unlink()\n", - "\n", - "sample_mcmc_klmc2(\n", - " sd_model=sd_model,\n", - " init_image=init_image,\n", - " height=height,\n", - " width=width,\n", - " n=n,\n", - " hvp_method=hvp_method,\n", - " prompts=prompts,\n", - " settings=curved_settings,\n", - " resume=resume,\n", - " resume_from=resume_from,\n", - " img_init_steps=img_init_steps,\n", - " stuff_to_plot=stuff_to_plot,\n", - " checkpoint_every=checkpoint_every,\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "DjwY7XrooLX_" - }, - "outputs": [], - "source": [ - "#@title Make the video\n", - "\n", - "if 'width' not in locals():\n", - " width = height = 512\n", - "\n", - "\n", - "# @markdown If your video is larger than a few MB, attempting to embed it will probably crash\n", - "# @markdown the session. 
If this happens, view the generated video after downloading it first.\n", - "embed_video = True # @param {type:'boolean'}\n", - "download_video = False # @param {type:'boolean'}\n", - "\n", - "upscale_video = False # @param {type:'boolean'}\n", - "\n", - "\n", - "outdir_str = str(outdir)\n", - "\n", - "fps = 14 # @param {type:\"integer\"}\n", - "out_fname = \"out.mp4\" # @param {type: \"string\"}\n", - "\n", - "out_fullpath = str( outdir / out_fname )\n", - "print(f\"Video will be saved to: {out_fullpath}\")\n", - "\n", - "compile_video_cmd = f\"ffmpeg -y -r {fps} -i 'out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p \" # {out_fname}\"\n", - "if upscale_video:\n", - " compile_video_cmd += f\"-vf scale={2*width}x{2*height}:flags=lanczos \"\n", - "compile_video_cmd += f\"{out_fname}\"\n", - "\n", - "print('\\nMaking the video...\\n')\n", - "!cd {outdir_str}; {compile_video_cmd}\n", - "\n", - "\n", - "debug=True\n", - "if debug:\n", - " #outdir_str = \"debug_frames\"\n", - " print(\"\\nMaking debug video...\")\n", - " #!cd debug_frames; ffmpeg -y -r {fps} -i 'prompts_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", - " !cd {debug_dir}; ffmpeg -y -r {fps} -i 'debug_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", - "\n", - "if embed_video:\n", - " print('\\nThe video:')\n", - " show_video(out_fullpath)\n", - " if debug:\n", - " show_video(debug_dir / \"debug_out.mp4\")\n", - "\n", - "if download_video and probably_using_colab:\n", - " from google.colab import files\n", - " files.download(out_fullpath)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rK_GlP_7WJiu", - "tags": [] - }, - "outputs": [], - "source": [ - "#@title Licensed under the MIT License { display-mode: \"form\" }\n", - "\n", - "# Copyright (c) 2022 Katherine Crowson \n", - "# Copyright (c) 2023 David Marx \n", - "# Copyright (c) 2022 deforum and contributors\n", - "\n", - "# Permission is hereby granted, free of 
charge, to any person obtaining a copy\n", - "# of this software and associated documentation files (the \"Software\"), to deal\n", - "# in the Software without restriction, including without limitation the rights\n", - "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", - "# copies of the Software, and to permit persons to whom the Software is\n", - "# furnished to do so, subject to the following conditions:\n", - "\n", - "# The above copyright notice and this permission notice shall be included in\n", - "# all copies or substantial portions of the Software.\n", - "\n", - "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", - "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", - "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", - "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", - "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", - "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", - "# THE SOFTWARE." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "private_outputs": true, - "provenance": [] - }, - "gpuClass": "premium", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "ff1624fd81a21ea709585fb1fdce5419f857f6a9e76cb1632f1b8b574978f9ee" - } - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "15BNHICpOOXg" + }, + "source": [ + "# Stable Diffusion KLMC2 Animation\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "\n", + "Notebook by [Katherine Crowson](https://twitter.com/RiversHaveWings), modified by [David Marx](https://twitter.com/DigThatData).\n", + "\n", + "Sponsored by [StabilityAI](https://twitter.com/stabilityai)\n", + "\n", + "Generate animations with [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) 1.4, using the [KLMC2 discretization of underdamped Langevin dynamics](https://arxiv.org/abs/1807.09382). The notebook is largely inspired by [Ajay Jain](https://twitter.com/ajayj_) and [Ben Poole](https://twitter.com/poolio)'s paper [Journey to the BAOAB-limit](https://www.ajayjain.net/journey)—thank you so much for it!\n", + "\n", + "---\n", + "\n", + "## Modifications Provenance\n", + "\n", + "Original notebook URL - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m8ovBpO2QilE2o4O-p2PONSwqGn4_x2G)\n", + "\n", + "Features and QOL Modifications by [David Marx](https://twitter.com/DigThatData) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmarx/notebooks/blob/main/Stable_Diffusion_KLMC2_Animation.ipynb)\n", + "\n", + "Last updated date (effectively, notebook version): 2022-02-20\n", + "\n", + "* Keyframed prompts and settings\n", + "* Multiprompt conditioning w independent prompt schedules\n", + "* Set seed for deterministic output\n", + "* Mount Google Drive\n", + "* Faster Setup\n", + "* Init image\n", + "* Alt-checkpoint loading consistent w/deforum\n", + "* Set output filename\n", + "* Fancy GPU info\n", + "* Video embed optional\n", + "* ~~Cheaper default runtime~~ torn about this\n", + "* Local setup\n", + "* New VAE option\n", + "* Smooth interpolation for settings curves\n", + "* Settings curves specified via simple DSL\n", + "* Exposed `refinement_steps` parameter\n", + "* Custom output resolution\n", + "* Optional video upscale\n", + "* Optional resume, user can specify 
resumption frame (auto-checkpoints every 10 frames)\n", + "* Optional archival\n", + "* Assorted refactoring\n", + "* Debugging plots and animations\n", + "\n", + "## Local Setup\n", + "\n", + "Download the repo containing this notebook and supplementary setup files.\n", + "\n", + "```\n", + "git clone https://github.com/dmarx/notebooks\n", + "cd notebooks\n", + "```\n", + "\n", + "Strongly recommend setting up and activating a virtual environment first. Here's one option that is built into python, windows users in particular might want to consider using anaconda as an alternative.\n", + "\n", + "```bash\n", + "python3 -m venv _venv\n", + "source _venv/bin/activate\n", + "pip install jupyter\n", + "```\n", + "\n", + "With this venv created, in the future you only need to run `source _venv/bin/activate` to activate it.\n", + "\n", + "You can now start a local jupyter instance from the terminal in which the virtual environment is activated by running the `jupyter` command, or alternatively select the new virtualenv as the python environment in your IDE of choice. When you run the notebook's setup cells, it should detect that local setup needs to be performed and modify its setup procedure appropriately.\n", + "\n", + "A common source of errors is user confusion between the python environment running the notebook and an intended virtual environment into which setup has already been performed. 
To validate that you are using the python environment you think you are, run the command `which python` (this locates the executable associated with the `python` command) both inside the notebook and in a terminal in which your venv is activated: the results should be identical.\n", + "\n", + "## Contact\n", + "\n", + "Report bugs or feature ideas here: https://github.com/dmarx/notebooks/issues" + ] }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Ty3IOeXbLzvc", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Check GPU\n", + "#!nvidia-smi\n", + "\n", + "import pandas as pd\n", + "import subprocess\n", + "\n", + "def gpu_info():\n", + " outv = subprocess.run([\n", + " 'nvidia-smi',\n", + " # these lines concatenate into a single query string\n", + " '--query-gpu='\n", + " 'timestamp,'\n", + " 'name,'\n", + " 'utilization.gpu,'\n", + " 'utilization.memory,'\n", + " 'memory.used,'\n", + " 'memory.free,'\n", + " ,\n", + " '--format=csv'\n", + " ],\n", + " stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + "\n", + " header, rec = outv.split('\\n')[:-1]\n", + " return pd.DataFrame({' '.join(k.strip().split('.')).capitalize():v for k,v in zip(header.split(','), rec.split(','))}, index=[0]).T\n", + "\n", + "gpu_info()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "kelHR9VM1-hg", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Install Dependencies\n", + "\n", + "# @markdown Your runtime will automatically restart after running this cell.\n", + "# @markdown You should only need to run this cell when setting up a new runtime. 
After future runtime restarts,\n", + "# @markdown you should be able to skip this cell.\n", + "\n", + "import warnings\n", + "\n", + "probably_using_colab = False\n", + "try:\n", + " import google\n", + " probably_using_colab = True\n", + "except ImportError:\n", + " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", + "\n", + "# @markdown Not recommended for colab users. This notebook is currently configured to only make this\n", + "# @markdown option available for local install.\n", + "use_xformers = False\n", + "\n", + "try:\n", + " import keyframed\n", + "except ImportError:\n", + " if probably_using_colab:\n", + " !pip install ftfy einops braceexpand requests transformers clip open_clip_torch omegaconf pytorch-lightning kornia k-diffusion ninja omegaconf\n", + " !pip install -U git+https://github.com/huggingface/huggingface_hub\n", + " !pip install napm keyframed\n", + " else:\n", + " !pip install -r klmc2/requirements.txt\n", + " if use_xformers:\n", + " !pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n", + "\n", + " exit() # restarts the runtime" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "fJZtXShcPXx5", + "tags": [] + }, + "outputs": [], + "source": [ + "# @markdown # Setup Workspace { display-mode: \"form\" }\n", + "\n", + "###################\n", + "# Setup Workspace #\n", + "###################\n", + "\n", + "import os\n", + "from pathlib import Path\n", + "import warnings\n", + "\n", + "probably_using_colab = False\n", + "try:\n", + " import google\n", + " if Path('/content').exists():\n", + " probably_using_colab = True\n", + " print(\"looks like we're in colab\")\n", + " else:\n", + " print(\"looks like we're not in colab\")\n", + "except ImportError:\n", + " warnings.warn(\"Unable to import `google`, assuming this means we're using a local runtime\")\n", + "\n", + "\n", + "mount_gdrive = True # 
@param {type:'boolean'}\n", + "\n", + "# defaults\n", + "outdir = Path('./frames')\n", + "if not os.environ.get('XDG_CACHE_HOME'):\n", + " os.environ['XDG_CACHE_HOME'] = str(Path('~/.cache').expanduser())\n", + "\n", + "if mount_gdrive and probably_using_colab:\n", + " from google.colab import drive\n", + " drive.mount('/content/drive')\n", + " Path('/content/drive/MyDrive/AI/models/.cache/').mkdir(parents=True, exist_ok=True) \n", + " os.environ['XDG_CACHE_HOME']='/content/drive/MyDrive/AI/models/.cache'\n", + " outdir = Path('/content/drive/MyDrive/AI/klmc2/frames/')\n", + "\n", + "# make sure the paths we need exist\n", + "outdir.mkdir(parents=True, exist_ok=True)\n", + "debug_dir = outdir.parent / 'debug_frames'\n", + "debug_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "\n", + "os.environ['NAPM_PATH'] = str( Path(os.environ['XDG_CACHE_HOME']) / 'napm' )\n", + "Path(os.environ['NAPM_PATH']).mkdir(parents=True, exist_ok=True)\n", + "\n", + "\n", + "import napm\n", + "\n", + "url = 'https://github.com/Stability-AI/stablediffusion'\n", + "napm.pseudoinstall_git_repo(url, add_install_dir_to_path=True)\n", + "\n", + "\n", + "##### Moved from model loading cell\n", + "\n", + "if probably_using_colab:\n", + " models_path = \"/content/models\" #@param {type:\"string\"}\n", + "else:\n", + " models_path = os.environ['XDG_CACHE_HOME']\n", + "\n", + "if mount_gdrive and probably_using_colab:\n", + " models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n", + " models_path = models_path_gdrive\n", + "\n", + "if not Path(models_path).exists():\n", + " Path(models_path).mkdir(parents=True, exist_ok=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "y2jXKIf2ZkT8", + "tags": [] + }, + "outputs": [], + "source": [ + "# @markdown # Imports and Definitions { display-mode: \"form\" }\n", + "\n", + "###########\n", + "# imports #\n", + "###########\n", + "\n", + "# importing napm 
puts the stable diffusion repo on the PATH, which is where `ldm` imports from\n", + "import napm\n", + "from ldm.util import instantiate_from_config\n", + "\n", + "from base64 import b64encode\n", + "from collections import defaultdict\n", + "from concurrent import futures\n", + "import math\n", + "from pathlib import Path\n", + "import random\n", + "import re\n", + "import requests\n", + "from requests.exceptions import HTTPError\n", + "import sys\n", + "import time\n", + "from urllib.parse import urlparse\n", + "import warnings\n", + "\n", + "import functorch\n", + "import huggingface_hub\n", + "from IPython.display import display, Video, HTML\n", + "import k_diffusion as K\n", + "from keyframed import Curve, ParameterGroup, SmoothCurve\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np \n", + "from omegaconf import OmegaConf\n", + "import PIL\n", + "from PIL import Image\n", + "import torch\n", + "from torch import nn\n", + "from tqdm.auto import tqdm, trange\n", + "\n", + "from loguru import logger\n", + "import time\n", + "from natsort import natsorted\n", + "\n", + "\n", + "#########################\n", + "# Define useful globals #\n", + "#########################\n", + "\n", + "cpu = torch.device(\"cpu\")\n", + "device = torch.device(\"cuda\")\n", + "\n", + "\n", + "##############################\n", + "# Define necessary functions #\n", + "##############################\n", + " \n", + "import matplotlib.image\n", + "import numpy as np\n", + " \n", + "def get_latest_frame(i=None, latest_frame_fpath=None):\n", + " latest_frame = None\n", + " if latest_frame_fpath is not None:\n", + " latest_frame = latest_frame_fpath\n", + " \n", + " if (latest_frame is None) and (i is None):\n", + " frames = Path('frames').glob(\"*.png\")\n", + " #logger.debug(f\"frames: {len(frames)}\")\n", + " latest_frame = natsort.sort(frames)[-1]\n", + " i = re.findall('out_([0-9]+).png', latest_frame.name)\n", + " else:\n", + " latest_frame = Path('frames') / 
f\"out_{i:05}.png\"\n", + " logger.debug(f'-latest_frame: {latest_frame}')\n", + " #return Image.open(latest_frame)\n", + " img = matplotlib.image.imread(latest_frame)\n", + " return np.flip(img, axis=0) # up/down\n", + "\n", + "def plot_prompts(prompts=None, n=1000, settings=None, **kargs):\n", + " if prompts is not None:\n", + " for prompt in prompts:\n", + " prompt.weight.plot(n=n, **kargs)\n", + "\n", + "def plot_param(param, settings=None, prompts=None, n=1000, **kargs):\n", + " settings.parameters[param].plot(n=n, **kargs)\n", + " \n", + "# move imports up\n", + "import base64\n", + "from io import BytesIO\n", + "from functools import partial\n", + " \n", + "@logger.catch\n", + "def write_debug_frame_at_(\n", + " i=None,\n", + " n=300, \n", + " prompts=None, \n", + " stuff_to_plot=['prompts'], \n", + " latest_frame_fpath=None,\n", + " pil_image=None,\n", + " settings=None,\n", + "):\n", + " plotting_funcs = {\n", + " 'prompts': plot_prompts,\n", + " 'g': partial(plot_param, param='g'),\n", + " 'h': partial(plot_param, param='h'),\n", + " 'sigma': partial(plot_param, param='sigma'),\n", + " 'gamma': partial(plot_param, param='gamma'),\n", + " 'alpha': partial(plot_param, param='alpha'),\n", + " 'tau': partial(plot_param, param='tau'),\n", + " }\n", + " \n", + " # i feel like this line of code justifies the silly variable name\n", + " if not stuff_to_plot:\n", + " return\n", + " \n", + " #stuff_to_plot = []\n", + " \n", + " test_im = pil_image\n", + " if pil_image is None:\n", + " test_im = get_latest_frame(i, latest_frame_fpath)\n", + "\n", + " fig = plt.figure()\n", + " #axsRight = fig.subplots(3, 1, sharex=True)\n", + " #ax = axsRight[0]\n", + " ax_objs = fig.subplots(len(stuff_to_plot), 1, sharex=True)\n", + " \n", + " #width, height = test_im.size\n", + " height, width = test_im.size\n", + " fig.set_size_inches(height/fig.dpi, width/fig.dpi )\n", + " \n", + " buffer = BytesIO()\n", + " for j, category in enumerate(stuff_to_plot):\n", + " ax = ax_objs\n", + 
" if len(stuff_to_plot) > 1:\n", + " ax = ax_objs[j]\n", + " plt.sca(ax)\n", + " plt.tight_layout()\n", + " plt.axis('off')\n", + " \n", + " plotting_funcs[category](prompts=prompts, settings=settings, n=n, zorder=1)\n", + " plt.axvline(x=i)\n", + " \n", + " \n", + "\n", + " #plt.margins(0)\n", + " fig.savefig(buffer, transparent=True) \n", + " plt.close()\n", + "\n", + " buffer.seek(0)\n", + " plot_pil = Image.open(buffer)\n", + " #buffer.close() # throws error here\n", + "\n", + " #debug_im_path = Path('debug_frames') / f\"{category}_out_{i:05}.png\"\n", + " #debug_im_path = Path('debug_frames') / f\"debug_out_{i:05}.png\"\n", + " debug_im_path = debug_dir / f\"debug_out_{i:05}.png\"\n", + " test_im = test_im.convert('RGBA')\n", + " test_im.paste(plot_pil, (0,0), plot_pil)\n", + " test_im.save(debug_im_path)\n", + " #display(test_im) # maybe?\n", + " buffer.close() # I guess?\n", + " \n", + " return test_im, plot_pil\n", + "\n", + "##############################\n", + "\n", + "class Prompt:\n", + " def __init__(\n", + " self,\n", + " text,\n", + " weight_schedule,\n", + " ):\n", + " c = sd_model.get_learned_conditioning([text])\n", + " self.text=text\n", + " self.encoded=c\n", + " self.weight = SmoothCurve(weight_schedule)\n", + "\n", + "\n", + "def handle_chigozienri_curve_format(value_string):\n", + " if value_string.startswith('(') and value_string.endswith(')'):\n", + " value_string = value_string[1:-1]\n", + " return value_string\n", + "\n", + "def parse_curve_string(txt, f=float):\n", + " schedule = {}\n", + " for tokens in txt.split(','):\n", + " k,v = tokens.split(':')\n", + " v = handle_chigozienri_curve_format(v)\n", + " schedule[int(k)] = f(v)\n", + " return schedule\n", + "\n", + "def parse_curvable_string(param, is_int=False):\n", + " if isinstance(param, dict):\n", + " return param\n", + " f = float\n", + " if is_int:\n", + " f = int\n", + " try:\n", + " return f(param)\n", + " except ValueError:\n", + " return parse_curve_string(txt=param, f=f)\n", 
+ "\n", + "##################\n", + "\n", + "def show_video(video_path, video_width=512):\n", + " return display(Video(video_path, width=video_width))\n", + "\n", + "if probably_using_colab:\n", + " def show_video(video_path, video_width=512):\n", + " video_file = open(video_path, \"r+b\").read()\n", + " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", + " return display(HTML(f\"\"\"\"\"\"))\n", + "\n", + "##################\n", + "\n", + "class NormalizingCFGDenoiser(nn.Module):\n", + " def __init__(self, model, g):\n", + " super().__init__()\n", + " self.inner_model = model\n", + " self.g = g\n", + " self.eps_norms = defaultdict(lambda: (0, 0))\n", + "\n", + " def mean_sq(self, x):\n", + " return x.pow(2).flatten(1).mean(1)\n", + "\n", + " @torch.no_grad()\n", + " def update_eps_norm(self, eps, sigma):\n", + " sigma = sigma[0].item()\n", + " eps_norm = self.mean_sq(eps).mean()\n", + " eps_norm_avg, count = self.eps_norms[sigma]\n", + " eps_norm_avg = eps_norm_avg * count / (count + 1) + eps_norm / (count + 1)\n", + " self.eps_norms[sigma] = (eps_norm_avg, count + 1)\n", + " return eps_norm_avg\n", + "\n", + " def forward(self, x, sigma, uncond, cond, g):\n", + " x_in = torch.cat([x] * 2)\n", + " sigma_in = torch.cat([sigma] * 2)\n", + " cond_in = torch.cat([uncond, cond])\n", + "\n", + " denoised = self.inner_model(x_in, sigma_in, cond=cond_in)\n", + " eps = K.sampling.to_d(x_in, sigma_in, denoised)\n", + " eps_uc, eps_c = eps.chunk(2)\n", + " eps_norm = self.update_eps_norm(eps, sigma).sqrt()\n", + " c = eps_c - eps_uc\n", + " cond_scale = g * eps_norm / self.mean_sq(c).sqrt()\n", + " eps_final = eps_uc + c * K.utils.append_dims(cond_scale, x.ndim)\n", + " return x - eps_final * K.utils.append_dims(sigma, eps.ndim)\n", + "\n", + "#########################\n", + "\n", + "def write_klmc2_state(**state):\n", + " st = time.time()\n", + " obj = {}\n", + " for k,v in state.items():\n", + " try:\n", + " v = v.clone().detach().cpu()\n", + " 
except AttributeError:\n", + " # if it doesn't have a detach method, we don't need to worry about any preprocessing\n", + " pass\n", + " obj[k] = v\n", + "\n", + " checkpoint_fpath = Path(outdir) / f\"klmc2_state_{state.get('i',0):05}.ckpt\"\n", + " with open(checkpoint_fpath, 'wb') as f:\n", + " torch.save(obj, f=f)\n", + " et = time.time()\n", + " #logger.debug(f\"checkpointing: {et-st}\")\n", + " # to do: move to callback? thread executor, anyway\n", + "\n", + "def read_klmc2_state(root=outdir, latest_frame=-1):\n", + " state = {}\n", + " checkpoints = [str(p) for p in Path(root).glob(\"*.ckpt\")]\n", + " if not checkpoints:\n", + " return None\n", + " checkpoints = natsorted(checkpoints)\n", + " if latest_frame < 0:\n", + " ckpt_fpath = checkpoints[-1]\n", + " else:\n", + " for fname in checkpoints:\n", + " frame_id = re.findall(r'([0-9]+).ckpt', fname)[0]\n", + " if int(frame_id) <= latest_frame:\n", + " ckpt_fpath = fname\n", + " else:\n", + " break\n", + " logger.debug(ckpt_fpath)\n", + " with open(ckpt_fpath,'rb') as f:\n", + " state = torch.load(f=f,map_location='cuda')\n", + " return state\n", + "\n", + "def load_init_image(init_image, height, width):\n", + " if not Path(init_image).exists():\n", + " raise FileNotFoundError(f\"Unable to locate init image from path: {init_image}\")\n", + " \n", + " \n", + " from PIL import Image\n", + " import numpy as np\n", + "\n", + " init_im_pil = Image.open(init_image)\n", + "\n", + " #x_pil = init_im_pil.resize([512,512])\n", + " x_pil = init_im_pil.resize([height,width])\n", + " x_np = np.array(x_pil.convert('RGB')).astype(np.float16) / 255.0\n", + " x = x_np[None].transpose(0, 3, 1, 2)\n", + " x = 2.*x - 1.\n", + " x = torch.from_numpy(x).to('cuda')\n", + " return x\n", + "\n", + "def save_image_fn(image, name, i, n, prompts=None, settings=None, stuff_to_plot=['prompts']):\n", + " pil_image = K.utils.to_pil_image(image)\n", + " if i % 10 == 0 or i == n - 1:\n", + " print(f'\\nIteration {i}/{n}:')\n", + " 
display(pil_image)\n", + " if i == n - 1:\n", + " print('\\nDone!')\n", + " pil_image.save(name)\n", + " if stuff_to_plot:\n", + " #logger.debug(stuff_to_plot)\n", + " #write_debug_frame_at_(i, prompts=prompts)\n", + " debug_frame, debug_plot = write_debug_frame_at_(i=i,n=n, prompts=prompts, settings=settings, stuff_to_plot=stuff_to_plot, pil_image=pil_image)\n", + " if i % 10 == 0 or i == n - 1:\n", + " #display(debug_frame)\n", + " display(debug_plot)\n", + "\n", + "###############################\n", + "\n", + "@torch.no_grad()\n", + "def sample_mcmc_klmc2(\n", + " sd_model, \n", + " init_image,\n", + " height:int,\n", + " width:int,\n", + " n:int, \n", + " hvp_method:str='reverse', \n", + " prompts:list=None,\n", + " settings:ParameterGroup=None,\n", + " resume:bool = False,\n", + " resume_from:int=-1,\n", + " img_init_steps:int=None,\n", + " stuff_to_plot:list=None,\n", + " checkpoint_every:int=10,\n", + "):\n", + "\n", + " if stuff_to_plot is None:\n", + " stuff_to_plot = ['prompts','h']\n", + " \n", + " torch.cuda.empty_cache()\n", + "\n", + " wrappers = {'eps': K.external.CompVisDenoiser, 'v': K.external.CompVisVDenoiser}\n", + " g = settings[0]['g']\n", + "\n", + " model_wrap = wrappers[sd_model.parameterization](sd_model)\n", + " model_wrap_cfg = NormalizingCFGDenoiser(model_wrap, g)\n", + " sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()\n", + " model = model_wrap_cfg\n", + "\n", + " uc = sd_model.get_learned_conditioning([''])\n", + " extra_args = {\n", + " 'uncond': uc, \n", + " #'cond': prompts[0].encoded,\n", + " 'g': settings[0]['g']\n", + " }\n", + "\n", + " sigma = settings[0]['sigma']\n", + "\n", + " with torch.cuda.amp.autocast(), futures.ThreadPoolExecutor() as ex:\n", + " def callback(info):\n", + " i = info['i']\n", + " rgb = sd_model.decode_first_stage(info['denoised'] )\n", + " ex.submit(save_image_fn, image=rgb, name=(outdir / f'out_{i:05}.png'), i=i, n=n, prompts=prompts, settings=settings, 
stuff_to_plot=stuff_to_plot)\n", + "\n", + " # Initialize the chain\n", + " print('Initializing the chain...')\n", + "\n", + " # to do: if resuming, generating this init image is unnecessary\n", + " x = None\n", + " if init_image:\n", + " print(\"loading init image\")\n", + " x = load_init_image(init_image, height, width)\n", + " # convert RGB to latent\n", + " x = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x))\n", + " print(\"init image loaded.\")\n", + "\n", + " print('Actually doing the sampling...')\n", + "\n", + " i_resume=0\n", + " v = None\n", + " if resume:\n", + " state = read_klmc2_state(latest_frame=resume_from)\n", + " if state:\n", + " x, v, i_resume = state['x'], state['v'], state['i']\n", + " # to do: resumption of settings\n", + " settings_i = state['settings_i']\n", + " settings[i]['h'] = settings_i['h']\n", + " settings[i]['gamma'] = settings_i['gamma']\n", + " settings[i]['alpha'] = settings_i['alpha']\n", + " settings[i]['tau'] = settings_i['tau']\n", + " settings[i]['g'] = settings_i['g']\n", + " settings[i]['sigma'] = settings_i['sigma']\n", + " settings[i]['steps'] = settings_i['steps']\n", + " \n", + " # to do: use multicond for init image\n", + " # we want this test after resumption if resuming\n", + " if x is None:\n", + " print(\"No init image provided, generating a random init image\")\n", + " extra_args['cond'] = prompts[0].encoded\n", + " h=height//8\n", + " w=width//8\n", + " x = torch.randn([1, 4, h, w], device=device) * sigma_max\n", + " sigmas_pre = K.sampling.get_sigmas_karras(img_init_steps, sigma, sigma_max, device=x.device)[:-1]\n", + " x = K.sampling.sample_dpmpp_sde(model_wrap_cfg, x, sigmas_pre, extra_args=extra_args)\n", + "\n", + " # if not resuming, randomly initialize momentum\n", + " # this needs to be *after* generating X if we're going to...\n", + " if v is None:\n", + " v = torch.randn_like(x) * sigma\n", + "\n", + " # main sampling loop\n", + " for i in trange(n):\n", + " # fast-forward loop to 
resumption index\n", + " if resume and i < i_resume:\n", + " continue\n", + " # if resume and (i == i_resume):\n", + " # # should these values be written into settings[i]?\n", + " # h = settings_i['h']\n", + " # gamma = settings_i['gamma']\n", + " # alpha = settings_i['alpha']\n", + " # tau = settings_i['tau']\n", + " # g = settings_i['g']\n", + " # sigma = settings_i['sigma']\n", + " # steps = settings_i['steps']\n", + " # else:\n", + " h = settings[i]['h']\n", + " gamma = settings[i]['gamma']\n", + " alpha = settings[i]['alpha']\n", + " tau = settings[i]['tau']\n", + " g = settings[i]['g']\n", + " sigma = settings[i]['sigma']\n", + " steps = settings[i]['steps']\n", + "\n", + " h = torch.tensor(h, device=x.device)\n", + " gamma = torch.tensor(gamma, device=x.device)\n", + " alpha = torch.tensor(alpha, device=x.device)\n", + " tau = torch.tensor(tau, device=x.device)\n", + " sigma = torch.tensor(sigma, device=x.device)\n", + " steps = int(steps)\n", + " \n", + " sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma.item(), device=x.device)[:-1]\n", + "\n", + " x, v, grad = klmc2_step(\n", + " model,\n", + " prompts,\n", + " x,\n", + " v,\n", + " h,\n", + " gamma,\n", + " alpha,\n", + " tau,\n", + " g,\n", + " sigma,\n", + " sigmas,\n", + " steps,\n", + " hvp_method,\n", + " i,\n", + " callback,\n", + " extra_args,\n", + " )\n", + "\n", + " save_checkpoint = (i % checkpoint_every) == 0\n", + " if save_checkpoint:\n", + " settings_i = settings[i]\n", + " ex.submit(write_klmc2_state, v=v, x=x, i=i, settings_i=settings_i)\n", + " logger.debug(settings[i])\n", + "\n", + "\n", + "def hvp_fn_forward_functorch(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " jvp_fn = lambda v: functorch.jvp(grad_fn, (x, sigma), (v, torch.zeros_like(sigma)))\n", + " grad, jvp_out = functorch.vmap(jvp_fn)(v)\n", + " 
return grad[0], jvp_out\n", + "\n", + "def hvp_fn_reverse(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " vjps = []\n", + " with torch.enable_grad():\n", + " x_ = x.clone().requires_grad_()\n", + " grad = grad_fn(x_, sigma)\n", + " for k, item in enumerate(v):\n", + " vjp_out = torch.autograd.grad(grad, x_, item, retain_graph=k < len(v) - 1)[0]\n", + " vjps.append(vjp_out)\n", + " return grad, torch.stack(vjps)\n", + "\n", + "def hvp_fn_zero(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " return grad_fn(x, sigma), torch.zeros_like(v)\n", + "\n", + "def hvp_fn_fake(model, x, sigma, v, alpha, extra_args):\n", + " def grad_fn(x, sigma):\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x, sigma * s_in, **extra_args)\n", + " return (x - denoised) + alpha * x\n", + " return grad_fn(x, sigma), (1 + alpha) * v\n", + "\n", + "\n", + "def multicond_hvp(model, x, sigma, v, alpha, extra_args, prompts, hvp_fn, i):\n", + "\n", + " # loop over prompts and aggregate gradients for multicond\n", + " grad = torch.zeros_like(x)\n", + " h2_v = torch.zeros_like(x)\n", + " h2_noise_v2 = torch.zeros_like(x)\n", + " h2_noise_x2 = torch.zeros_like(x)\n", + " wt_norm = 0\n", + " for prompt in prompts:\n", + " wt = prompt.weight[i]\n", + " if wt == 0:\n", + " continue\n", + " wt_norm += wt\n", + " wt = torch.tensor(wt, device=x.device)\n", + " extra_args['cond'] = prompt.encoded\n", + "\n", + " # Estimate gradient and hessian\n", + " grad_, (h2_v_, h2_noise_v2_, h2_noise_x2_) = hvp_fn(\n", + " model=model,\n", + " x=x, \n", + " sigma=sigma, \n", + " v=v,\n", + " alpha=alpha,\n", + " extra_args=extra_args,\n", + " )\n", + "\n", + " grad = grad + 
grad_ * wt \n", + " h2_v = h2_v + h2_v_ * wt\n", + " h2_noise_v2 = h2_noise_v2 + h2_noise_v2_ * wt\n", + " h2_noise_x2 = h2_noise_x2 + h2_noise_x2_ * wt\n", + "\n", + " # Normalize gradient to magnitude it'd have if just single prompt w/ wt=1.\n", + " # simplifies multicond w/o deep frying image or adding hyperparams\n", + " grad = grad / wt_norm \n", + " h2_v = h2_v / wt_norm\n", + " h2_noise_v2 = h2_noise_v2 / wt_norm\n", + " h2_noise_x2 = h2_noise_x2 / wt_norm\n", + "\n", + " return grad, h2_v, h2_noise_v2, h2_noise_v2, h2_noise_x2\n", + "\n", + "\n", + "\n", + "\n", + "def klmc2_step(\n", + " model,\n", + " prompts,\n", + " x,\n", + " v,\n", + " h,\n", + " gamma,\n", + " alpha,\n", + " tau,\n", + " g,\n", + " sigma,\n", + " sigmas,\n", + " steps,\n", + " hvp_method,\n", + " i,\n", + " callback,\n", + " extra_args,\n", + " ):\n", + "\n", + " #s_in = x.new_ones([x.shape[0]])\n", + "\n", + " # Model helper functions\n", + "\n", + " hvp_fns = {'forward-functorch': hvp_fn_forward_functorch,\n", + " 'reverse': hvp_fn_reverse,\n", + " 'zero': hvp_fn_zero,\n", + " 'fake': hvp_fn_fake}\n", + "\n", + " hvp_fn = hvp_fns[hvp_method]\n", + "\n", + " # KLMC2 helper functions\n", + " def psi_0(gamma, t):\n", + " return torch.exp(-gamma * t)\n", + "\n", + " def psi_1(gamma, t):\n", + " return -torch.expm1(-gamma * t) / gamma\n", + "\n", + " def psi_2(gamma, t):\n", + " return (torch.expm1(-gamma * t) + gamma * t) / gamma ** 2\n", + "\n", + " def phi_2(gamma, t_):\n", + " t = t_.double()\n", + " out = (torch.exp(-gamma * t) * (torch.expm1(gamma * t) - gamma * t)) / gamma ** 2\n", + " return out.to(t_)\n", + "\n", + " def phi_3(gamma, t_):\n", + " t = t_.double()\n", + " out = (torch.exp(-gamma * t) * (2 + gamma * t + torch.exp(gamma * t) * (gamma * t - 2))) / gamma ** 3\n", + " return out.to(t_)\n", + "\n", + "\n", + " # Compute model outputs and sample noise\n", + " x_trapz = torch.linspace(0, h, 1001, device=x.device)\n", + " y_trapz = [fun(gamma, x_trapz) for fun in (psi_0, 
psi_1, phi_2, phi_3)]\n", + " noise_cov = torch.tensor([[torch.trapz(y_trapz[i] * y_trapz[j], x=x_trapz) for j in range(4)] for i in range(4)], device=x.device)\n", + " noise_v, noise_x, noise_v2, noise_x2 = torch.distributions.MultivariateNormal(x.new_zeros([4]), noise_cov).sample(x.shape).unbind(-1)\n", + "\n", + " extra_args['g']=g\n", + "\n", + " # compute derivatives, multicond wrapper loops over prompts and averages derivatives\n", + " grad, h2_v, h2_noise_v2, h2_noise_v2, h2_noise_x2 = multicond_hvp(\n", + " model=model, \n", + " x=x, \n", + " sigma=sigma, \n", + " v=torch.stack([v, noise_v2, noise_x2]), # need a \"dummy\" v for init image generation\n", + " alpha=alpha, \n", + " extra_args=extra_args, \n", + " prompts=prompts, \n", + " hvp_fn=hvp_fn,\n", + " i=i,\n", + " )\n", + "\n", + " # DPM-Solver++(2M) refinement steps\n", + " x_refine = x\n", + " use_dpm = True\n", + " old_denoised = None\n", + " for j in range(len(sigmas) - 1):\n", + " if j == 0:\n", + " denoised = x_refine - grad\n", + " else:\n", + " s_in = x.new_ones([x.shape[0]])\n", + " denoised = model(x_refine, sigmas[j] * s_in, **extra_args)\n", + " dt_ode = sigmas[j + 1] - sigmas[j]\n", + " if not use_dpm or old_denoised is None or sigmas[j + 1] == 0:\n", + " eps = K.sampling.to_d(x_refine, sigmas[j], denoised)\n", + " x_refine = x_refine + eps * dt_ode\n", + " else:\n", + " h_ode = sigmas[j].log() - sigmas[j + 1].log()\n", + " h_last = sigmas[j - 1].log() - sigmas[j].log()\n", + " fac = h_ode / (2 * h_last)\n", + " denoised_d = (1 + fac) * denoised - fac * old_denoised\n", + " eps = K.sampling.to_d(x_refine, sigmas[j], denoised_d)\n", + " x_refine = x_refine + eps * dt_ode\n", + " old_denoised = denoised\n", + " if callback is not None:\n", + " callback({'i': i, 'denoised': x_refine})\n", + "\n", + " # Update the chain\n", + " noise_std = (2 * gamma * tau * sigma ** 2).sqrt()\n", + " v_next = 0 + psi_0(gamma, h) * v - psi_1(gamma, h) * grad - phi_2(gamma, h) * h2_v + noise_std * (noise_v - 
h2_noise_v2)\n", + " x_next = x + psi_1(gamma, h) * v - psi_2(gamma, h) * grad - phi_3(gamma, h) * h2_v + noise_std * (noise_x - h2_noise_x2)\n", + " v, x = v_next, x_next\n", + "\n", + " return x, v, grad " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "yt3d1hww17ST", + "tags": [] + }, + "outputs": [], + "source": [ + "#@markdown **Select and Load Model**\n", + "\n", + "## TO DO:\n", + "## - if local, try to load model from ~/.cache/huggingface/diffusers\n", + "\n", + "# modified from:\n", + "# https://github.com/deforum/stable-diffusion/blob/main/Deforum_Stable_Diffusion.ipynb\n", + "\n", + "import napm\n", + "from ldm.util import instantiate_from_config\n", + "\n", + "\n", + "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", + "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\", \"robo-diffusion-v1.ckpt\",\"waifu-diffusion-v1-3.ckpt\"]\n", + "if model_checkpoint == \"waifu-diffusion-v1-3.ckpt\":\n", + " model_checkpoint = \"model-epoch05-float16.ckpt\"\n", + "custom_config_path = \"\" #@param {type:\"string\"}\n", + "custom_checkpoint_path = \"\" #@param {type:\"string\"}\n", + "\n", + "half_precision = True # check\n", + "check_sha256 = False #@param {type:\"boolean\"}\n", + "\n", + "model_map = {\n", + " \"sd-v1-4-full-ema.ckpt\": {\n", + " 'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-4-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-4.ckpt\": {\n", + " 'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',\n", 
+ " 'requires_login': True,\n", + " },\n", + " \"sd-v1-3-full-ema.ckpt\": {\n", + " 'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-3.ckpt\": {\n", + " 'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-2-full-ema.ckpt\": {\n", + " 'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-2.ckpt\": {\n", + " 'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-1-full-ema.ckpt\": {\n", + " 'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',\n", + " 'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-1.ckpt\": {\n", + " 'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"robo-diffusion-v1.ckpt\": {\n", + " 'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',\n", + " 'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',\n", + " 'requires_login': False,\n", + " },\n", + " \"model-epoch05-float16.ckpt\": {\n", + " 
'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece',\n", + " 'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt',\n", + " 'requires_login': False,\n", + " },\n", + "}\n", + "\n", + "# config path\n", + "ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n", + "if os.path.exists(ckpt_config_path):\n", + " print(f\"{ckpt_config_path} exists\")\n", + "else:\n", + " #ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n", + " ckpt_config_path = \"./v1-inference.yaml\"\n", + " if not Path(ckpt_config_path).exists():\n", + " !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n", + " \n", + "print(f\"Using config: {ckpt_config_path}\")\n", + "\n", + "# checkpoint path or download\n", + "ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n", + "ckpt_valid = True\n", + "if os.path.exists(ckpt_path):\n", + " print(f\"{ckpt_path} exists\")\n", + "elif 'url' in model_map[model_checkpoint]:\n", + " url = model_map[model_checkpoint]['url']\n", + "\n", + " # CLI dialogue to authenticate download\n", + " if model_map[model_checkpoint]['requires_login']:\n", + " print(\"This model requires an authentication token\")\n", + " print(\"Please ensure you have accepted its terms of service before continuing.\")\n", + "\n", + " username = input(\"What is your huggingface username?:\")\n", + " token = input(\"What is your huggingface token?:\")\n", + "\n", + " _, path = url.split(\"https://\")\n", + "\n", + " url = f\"https://{username}:{token}@{path}\"\n", + "\n", + " # contact server for model\n", + " print(f\"Attempting to download {model_checkpoint}...this may take a while\")\n", + " ckpt_request = requests.get(url)\n", + " request_status = ckpt_request.status_code\n", + "\n", + " # inform 
user of errors\n", + " if request_status == 403:\n", + " raise ConnectionRefusedError(\"You have not accepted the license for this model.\")\n", + " elif request_status == 404:\n", + " raise ConnectionError(\"Could not make contact with server\")\n", + " elif request_status != 200:\n", + " raise ConnectionError(f\"Some other error has ocurred - response code: {request_status}\")\n", + "\n", + " # write to model path\n", + " with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file:\n", + " model_file.write(ckpt_request.content)\n", + "else:\n", + " print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n", + " ckpt_valid = False\n", + "\n", + "if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n", + " import hashlib\n", + " print(\"\\n...checking sha256\")\n", + " with open(ckpt_path, \"rb\") as f:\n", + " bytes = f.read() \n", + " hash = hashlib.sha256(bytes).hexdigest()\n", + " del bytes\n", + " if model_map[model_checkpoint][\"sha256\"] == hash:\n", + " print(\"hash is correct\\n\")\n", + " else:\n", + " print(\"hash in not correct\\n\")\n", + " ckpt_valid = False\n", + "\n", + "if ckpt_valid:\n", + " print(f\"Using ckpt: {ckpt_path}\")\n", + "\n", + "def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n", + " map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n", + " print(f\"Loading model from {ckpt}\")\n", + " pl_sd = torch.load(ckpt, map_location=map_location)\n", + " if \"global_step\" in pl_sd:\n", + " print(f\"Global Step: {pl_sd['global_step']}\")\n", + " sd = pl_sd[\"state_dict\"]\n", + " model = instantiate_from_config(config.model)\n", + " m, u = model.load_state_dict(sd, strict=False)\n", + " if len(m) > 0 and verbose:\n", + " print(\"missing keys:\")\n", + " print(m)\n", + " if len(u) > 0 and verbose:\n", + " print(\"unexpected keys:\")\n", + " print(u)\n", + "\n", + " if half_precision:\n", + " model = 
model.half().to(device)\n", + " else:\n", + " model = model.to(device)\n", + " model.eval()\n", + " return model\n", + "\n", + "if ckpt_valid:\n", + " local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n", + " model = load_model_from_config(local_config, f\"{ckpt_path}\", half_precision=half_precision)\n", + " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + " model = model.to(device)\n", + "\n", + " # Disable checkpointing as it is not compatible with the method\n", + " for module in model.modules():\n", + " if hasattr(module, 'checkpoint'):\n", + " module.checkpoint = False\n", + " if hasattr(module, 'use_checkpoint'):\n", + " module.use_checkpoint = False\n", + "\n", + " sd_model=model\n", + "\n", + "####################################################################\n", + "\n", + "use_new_vae = True #@param {type:\"boolean\"}\n", + "\n", + "if use_new_vae:\n", + "\n", + " # from kat's notebook again\n", + "\n", + " def download_from_huggingface(repo, filename):\n", + " while True:\n", + " try:\n", + " return huggingface_hub.hf_hub_download(repo, filename)\n", + " except HTTPError as e:\n", + " if e.response.status_code == 401:\n", + " # Need to log into huggingface api\n", + " huggingface_hub.interpreter_login()\n", + " continue\n", + " elif e.response.status_code == 403:\n", + " # Need to do the click through license thing\n", + " print(f'Go here and agree to the click through license on your account: https://huggingface.co/{repo}')\n", + " input('Hit enter when ready:')\n", + " continue\n", + " else:\n", + " raise e\n", + "\n", + " vae_840k_model_path = download_from_huggingface(\"stabilityai/sd-vae-ft-mse-original\", \"vae-ft-mse-840000-ema-pruned.ckpt\")\n", + "\n", + " def load_model_from_config_kc(config, ckpt):\n", + " print(f\"Loading model from {ckpt}\")\n", + " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n", + " sd = pl_sd[\"state_dict\"]\n", + " config = OmegaConf.load(config)\n", + "\n", + " 
try:\n", + " config['model']['params']['lossconfig']['target'] = \"torch.nn.Identity\"\n", + " print('Patched VAE config.')\n", + " except KeyError:\n", + " pass\n", + "\n", + " model = instantiate_from_config(config.model)\n", + " m, u = model.load_state_dict(sd, strict=False)\n", + " model = model.to(cpu).eval().requires_grad_(False)\n", + " return model\n", + "\n", + " vaemodel_yaml_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/models/first_stage_models/kl-f8/config.yaml\"\n", + " vaemodel_yaml_fname = 'config_vae_kl-f8.yaml'\n", + " vaemodel_yaml_fname_git = \"latent-diffusion/models/first_stage_models/kl-f8/config.yaml\"\n", + " if Path(vaemodel_yaml_fname_git).exists():\n", + " vae_model = load_model_from_config_kc(vaemodel_yaml_fname_git, vae_840k_model_path).half().to(device)\n", + " else:\n", + " if not Path(vaemodel_yaml_fname).exists():\n", + " !wget {vaemodel_yaml_url} -O {vaemodel_yaml_fname}\n", + " vae_model = load_model_from_config_kc(vaemodel_yaml_fname, vae_840k_model_path).half().to(device)\n", + "\n", + " del sd_model.first_stage_model\n", + " sd_model.first_stage_model = vae_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ZljSF1ePnBl4", + "tags": [] + }, + "outputs": [], + "source": [ + "# @title Settings\n", + "\n", + "# @markdown The number of frames to sample:\n", + "n = 300 # @param {type:\"integer\"}\n", + "\n", + "# @markdown height and width must be multiples of 8 (e.g. 256, 512, 768, 1024)\n", + "height = 512 # @param {type:\"integer\"}\n", + "\n", + "width = 512 # @param {type:\"integer\"}\n", + "\n", + "\n", + "# @markdown If seed is negative, a random seed will be used\n", + "seed = -1 # @param {type:\"number\"}\n", + "\n", + "init_image = \"\" # @param {type:'string'}\n", + "\n", + "# @markdown ---\n", + "\n", + "# @markdown Settings below this line can be parameterized using keyframe syntax: `\"time:weight, time:weight, ...\". 
\n", + "# @markdown Over spans where values of weights change, intermediate values will be interpolated using an \"s\" shaped curve.\n", + "# @markdown If a value for keyframe 0 is not specified, it is presumed to be `0:0`.\n", + "\n", + "# @markdown The strength of the conditioning on the prompt:\n", + "g=\"0:0.1\" # @param {type:\"string\"}\n", + "\n", + "# @markdown The noise level to sample at\n", + "# @markdown Ramp up from a tiny sigma if using init image, e.g. `0:0.25, 100:2, ...`\n", + "# @markdown NB: Turning sigma *up* mid generation seems to work fine, but turning sigma *down* mid generation tends to \"deep fry\" the outputs\n", + "sigma = \"1.25\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Step size (range 0 to 1):\n", + "h = \"0:0.1, 30:0.1, 50:0.3, 70:0.1, 120:0.1, 140:.3, 160:.1, 210:.1, 230:.3, 250:.1\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Friction (2 is critically damped, lower -> smoother animation):\n", + "gamma = \"1.1\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Quadratic penalty (\"weight decay\") strength:\n", + "alpha = \"0.005\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Temperature (adjustment to the amount of noise added per step):\n", + "tau = \"1.0\" # @param {type:\"string\"}\n", + "\n", + "# @markdown Denoising refinement steps:\n", + "refinement_steps = \"6\" # @param {type:\"string\"}\n", + "\n", + "# @markdown If an init image is not provided, this is how many steps will be used when generating an initial state:\n", + "img_init_steps = 15 # @param {type:\"number\"}\n", + "\n", + "# @markdown The HVP method:\n", + "# @markdown
`forward-functorch` and `reverse` provide real second derivatives. Compatibility, speed, and memory usage vary by model and xformers configuration.\n", + "# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.\n", + "hvp_method = 'fake' # @param [\"forward-functorch\", \"reverse\", \"fake\", \"zero\"]\n", + "\n", + "checkpoint_every = 10 # @param {type:\"number\"}\n", + "\n", + "###########################\n", + "\n", + "assert (height % 8) == 0\n", + "assert (width % 8) == 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1pLTsdGBPXx6", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Prompts\n", + "\n", + "# [ \n", + "# [\"first prompt will be used to initialize the image\", {time:weight, time:weight...}], \n", + "# [\"more prompts if you want\", {...}], \n", + "# ...]\n", + "\n", + "# if a weight for time=0 isn't specified, the weight is assumed to be zero.\n", + "# if you want to \"fade in\" any prompts, it's best to have them start with a small but non-zero value, e.g. 
0.001\n", + "\n", + "prompt_params = [\n", + " # # FIRST PROMPT INITIALIZES IMAGE\n", + " #[\"sweetest puppy, golden retriever\", {0:.5, 30:0.5, 100:0.001}],\n", + " #[\"sweet old dog, golden retriever\", {0:0.001, 30:0.001, 100:0.5}],\n", + " #[\"happiest pupper, cutest dog evar, golden retriever, incredibly photogenic dog\", {0:1}],\n", + "\n", + " # # the 'flowers prompts' below go with a particular 'h' setting in the next cell\n", + " [\"incredibly beautiful orchids, a bouquet of orchids\", {0:1, 35:1, 50:0}],\n", + " [\"incredibly beautiful roses, a bouquet of roses\", {0:0.001, 35:0.001, 50:1, 120:1, 140:0}],\n", + " [\"incredibly beautiful carnations, a bouquet of carnations\", {0:0.001, 120:0.001, 140:1, 220:1, 240:0}],\n", + " [\"incredibly beautiful carnations, a bouquet of sunflowers\", {0:0.001, 220:0.001, 240:1}],\n", + " \n", + " # negative prompts\n", + " [\"watermark text\", {0:-0.1} ],\n", + " [\"jpeg artifacts\", {0:-0.1} ],\n", + " [\"artist's signature\", {0:-0.1} ],\n", + " [\"istockphoto, gettyimages, watermarked image\", {0:-0.1} ],\n", + "]\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "hZ0lh-WkdB19", + "tags": [] + }, + "outputs": [], + "source": [ + "# @title Build prompt and settings objects\n", + "\n", + "# @markdown some advanced features demonstrated in commented-out code in this cell\n", + "\n", + "plot_prompt_weight_curves = True # @param {type: 'boolean'}\n", + "plot_settings_weight_curves = True # @param {type: 'boolean'}\n", + "\n", + "#################\n", + "\n", + "# Build Prompt objects\n", + "\n", + "prompts = [\n", + " Prompt(text, weight_schedule) \n", + " for (text, weight_schedule) in prompt_params\n", + "]\n", + "\n", + "# uncomment to loop the prompts\n", + "#for p in prompts:\n", + "# if len(p.weight.keyframes) > 1: # ignore negative prompts\n", + "# p.weight.loop=True \n", + "\n", + "# uncomment to loop prompts in \"bounce\" mode\n", + 
"#for p in prompts:\n", + "# if len(p.weight.keyframes) > 1:\n", + "# p.weight.bounce=True \n", + "\n", + "#################\n", + "\n", + "# Build Settings object\n", + "\n", + "g = parse_curvable_string(g)\n", + "sigma = parse_curvable_string(sigma)\n", + "h = parse_curvable_string(h)\n", + "gamma = parse_curvable_string(gamma)\n", + "alpha = parse_curvable_string(alpha)\n", + "tau = parse_curvable_string(tau)\n", + "steps = parse_curvable_string(refinement_steps)\n", + "\n", + "\n", + "curved_settings = ParameterGroup({\n", + " 'g':SmoothCurve(g),\n", + " 'sigma':SmoothCurve(sigma),\n", + " #'h':SmoothCurve(h),\n", + " \n", + " # more concise notation for flowers demo:\n", + " 'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3}, bounce=True),\n", + " #'h':SmoothCurve({0:0.1, 30:0.1, 50:0.3, 70:0.1, 90:0.1}, loop=True),\n", + "\n", + " 'gamma':SmoothCurve(gamma),\n", + " 'alpha':SmoothCurve(alpha),\n", + " 'tau':SmoothCurve(tau),\n", + " 'steps':SmoothCurve(steps),\n", + "})\n", + "\n", + "\n", + "if plot_prompt_weight_curves:\n", + " for prompt in prompts:\n", + " prompt.weight.plot(n=n)\n", + " plt.title(\"prompt weight schedules\")\n", + " plt.show()\n", + "\n", + "\n", + "if plot_settings_weight_curves:\n", + " for name, curve in curved_settings.parameters.items():\n", + " curve.plot(n=n)\n", + " plt.title(name)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tthag9k67Uey", + "tags": [] + }, + "outputs": [], + "source": [ + "# @markdown running this cell saves the current settings to disk\n", + "\n", + "import keyframed.serialization\n", + "txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", + "\n", + "#print(txt)\n", + "\n", + "# sigma: 1.25\n", + "#\n", + "# becomes:\n", + "#\n", + "# sigma:\n", + "# curve:\n", + "# - - 0\n", + "# - 1.25\n", + "# - eased_lerp\n", + "#\n", + "# :\n", + "# curve:\n", + "# - - \n", + "# - \n", + "# - \n", + "# - \n", + "# - - \n", + "# - \n", + "# - - \n", + 
"# - \n", + "\n", + "with open(outdir / 'settings.yaml', 'w') as f:\n", + " f.write(txt)\n", + " \n", + "#########################\n", + "\n", + "# save prompts\n", + "\n", + "prompts_out = []\n", + "for prompt in prompts:\n", + " rec = {'prompt':prompt.text}\n", + " prompt.weight.loop=True\n", + " rec['schedule'] = prompt.weight.to_dict(simplify=True, for_yaml=True) #)\n", + " #rec.update( prompt.weight.to_dict(simplify=True, for_yaml=True) )\n", + " rec['prompt'] = prompt.text\n", + " prompts_out.append(rec)\n", + "\n", + "prompts_yaml = OmegaConf.to_yaml(OmegaConf.create({'prompts':prompts_out}))\n", + "\n", + "with open(outdir / 'prompts.yaml', 'w') as f:\n", + " f.write(prompts_yaml)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "srbY3kDa7Uey", + "tags": [] + }, + "outputs": [], + "source": [ + "# load settings from disk\n", + "\n", + "load_settings_from_disk = True # @param {type:'boolean'}\n", + "load_prompts_from_disk = True # @param {type:'boolean'}\n", + "\n", + "\n", + "if load_settings_from_disk:\n", + " with open(outdir / 'settings.yaml', 'r') as f:\n", + " curved_settings = keyframed.serialization.from_yaml(f.read())\n", + "\n", + "curved_settings.to_dict(simplify=True)['parameters']\n", + "#curved_settings.plot()\n", + "\n", + "###########################\n", + "\n", + "if load_prompts_from_disk:\n", + " with open(outdir / 'prompts.yaml', 'r') as f:\n", + " prompts_cfg = OmegaConf.load(f)\n", + " \n", + " prompts = []\n", + " for p in prompts_cfg.prompts:\n", + " weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.schedule))\n", + " P = Prompt(text=p.prompt, weight_schedule=weight_curve)\n", + " prompts.append(P)\n", + " #P.weight.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "i-_u1Q0wRqMb", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Generate Animation Frames\n", + "\n", + "resume = 
True # @param {type:'boolean'}\n", + "archive_old_work = False # @param {type:'boolean'}\n", + "\n", + "# -1 = most recent frame\n", + "resume_from = -1 # @param {type:'number'}\n", + "\n", + "# @markdown optional debugging plots\n", + "plot_prompt_weights = True # @param {type:'boolean'}\n", + "plot_h = False # @param {type:'boolean'}\n", + "plot_g = False # @param {type:'boolean'}\n", + "plot_sigma = False # @param {type:'boolean'}\n", + "plot_gamma = False # @param {type:'boolean'}\n", + "plot_alpha = False # @param {type:'boolean'}\n", + "plot_tau = False # @param {type:'boolean'}\n", + "\n", + "################\n", + "\n", + "_seed = seed\n", + "if seed < 0: \n", + " _seed = random.randrange(0, 4294967295)\n", + "print(f\"using seed: {_seed}\")\n", + "torch.manual_seed(_seed)\n", + "\n", + "stuff_to_plot = []\n", + "if plot_prompt_weights:\n", + " stuff_to_plot.append('prompts')\n", + "if plot_h:\n", + " stuff_to_plot.append('h')\n", + "if plot_g:\n", + " stuff_to_plot.append('g')\n", + "if plot_sigma:\n", + " stuff_to_plot.append('sigma')\n", + "if plot_gamma:\n", + " stuff_to_plot.append('gamma')\n", + "if plot_alpha:\n", + " stuff_to_plot.append('alpha')\n", + "if plot_tau:\n", + " stuff_to_plot.append('tau')\n", + "\n", + "if not resume:\n", + " if archive_old_work:\n", + " archive_dir = outdir.parent / 'archive' / str(int(time.time()))\n", + " archive_dir.mkdir(parents=True, exist_ok=True)\n", + " print(f\"Archiving contents of /frames, moving to: {archive_dir}\")\n", + " else:\n", + " print(\"Old contents of /frames being deleted. 
This can be prevented in the future by setting either 'resume' or 'archive_old_work' to True.\")\n", + " for p in outdir.glob(f'*'):\n", + " if archive_old_work:\n", + " target = archive_dir / p.name\n", + " p.rename(target)\n", + " else:\n", + " p.unlink()\n", + " for p in Path('debug_frames').glob(f'*'):\n", + " p.unlink()\n", + "\n", + "sample_mcmc_klmc2(\n", + " sd_model=sd_model,\n", + " init_image=init_image,\n", + " height=height,\n", + " width=width,\n", + " n=n,\n", + " hvp_method=hvp_method,\n", + " prompts=prompts,\n", + " settings=curved_settings,\n", + " resume=resume,\n", + " resume_from=resume_from,\n", + " img_init_steps=img_init_steps,\n", + " stuff_to_plot=stuff_to_plot,\n", + " checkpoint_every=checkpoint_every,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "DjwY7XrooLX_" + }, + "outputs": [], + "source": [ + "#@title Make the video\n", + "\n", + "if 'width' not in locals():\n", + " width = height = 512\n", + "\n", + "\n", + "# @markdown If your video is larger than a few MB, attempting to embed it will probably crash\n", + "# @markdown the session. 
If this happens, view the generated video after downloading it first.\n", + "embed_video = True # @param {type:'boolean'}\n", + "download_video = False # @param {type:'boolean'}\n", + "\n", + "upscale_video = False # @param {type:'boolean'}\n", + "\n", + "\n", + "outdir_str = str(outdir)\n", + "\n", + "fps = 14 # @param {type:\"integer\"}\n", + "out_fname = \"out.mp4\" # @param {type: \"string\"}\n", + "\n", + "out_fullpath = str( outdir / out_fname )\n", + "print(f\"Video will be saved to: {out_fullpath}\")\n", + "\n", + "compile_video_cmd = f\"ffmpeg -y -r {fps} -i 'out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p \" # {out_fname}\"\n", + "if upscale_video:\n", + " compile_video_cmd += f\"-vf scale={2*width}x{2*height}:flags=lanczos \"\n", + "compile_video_cmd += f\"{out_fname}\"\n", + "\n", + "print('\\nMaking the video...\\n')\n", + "!cd {outdir_str}; {compile_video_cmd}\n", + "\n", + "\n", + "debug=True\n", + "if debug:\n", + " #outdir_str = \"debug_frames\"\n", + " print(\"\\nMaking debug video...\")\n", + " #!cd debug_frames; ffmpeg -y -r {fps} -i 'prompts_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", + " !cd {debug_dir}; ffmpeg -y -r {fps} -i 'debug_out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p debug_out.mp4\n", + "\n", + "if embed_video:\n", + " print('\\nThe video:')\n", + " show_video(out_fullpath)\n", + " if debug:\n", + " show_video(debug_dir / \"debug_out.mp4\")\n", + "\n", + "if download_video and probably_using_colab:\n", + " from google.colab import files\n", + " files.download(out_fullpath)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rK_GlP_7WJiu", + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Licensed under the MIT License { display-mode: \"form\" }\n", + "\n", + "# Copyright (c) 2022 Katherine Crowson \n", + "# Copyright (c) 2023 David Marx \n", + "# Copyright (c) 2022 deforum and contributors\n", + "\n", + "# Permission is hereby granted, free of 
charge, to any person obtaining a copy\n", + "# of this software and associated documentation files (the \"Software\"), to deal\n", + "# in the Software without restriction, including without limitation the rights\n", + "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", + "# copies of the Software, and to permit persons to whom the Software is\n", + "# furnished to do so, subject to the following conditions:\n", + "\n", + "# The above copyright notice and this permission notice shall be included in\n", + "# all copies or substantial portions of the Software.\n", + "\n", + "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", + "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", + "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", + "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", + "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", + "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", + "# THE SOFTWARE." 
+ ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "private_outputs": true, + "provenance": [] + }, + "gpuClass": "premium", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "ff1624fd81a21ea709585fb1fdce5419f857f6a9e76cb1632f1b8b574978f9ee" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 6bc1ff525509ea5d9f9de7a854c9fb3e0cf5c80d Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 2 Mar 2023 12:43:08 -0800 Subject: [PATCH 09/14] user-friendlier yaml --- Stable_Diffusion_KLMC2_Animation.ipynb | 81 ++++++++++++++++---------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index a7c430f..41557a6 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1357,30 +1357,21 @@ "# @markdown running this cell saves the current settings to disk\n", "\n", "import keyframed.serialization\n", - "txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", - "\n", - "#print(txt)\n", - "\n", - "# sigma: 1.25\n", - "#\n", - "# becomes:\n", - "#\n", - "# sigma:\n", - "# curve:\n", - "# - - 0\n", - "# - 1.25\n", - "# - eased_lerp\n", - "#\n", - "# :\n", - "# curve:\n", - "# - - \n", - "# - \n", - "# - \n", - "# - \n", - "# - - \n", - "# - \n", - "# - - \n", - "# - \n", + "# verbose:\n", + "#txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", + "\n", + "# significantly less verbose:\n", + "simplified_settings = {}\n", + "simplified_settings__curves = {}\n", + "for param, curve in curved_settings.parameters.items():\n", + " kf0 = 
curve._data[0]\n", + " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", + " simplified_settings[param] = kf0.value\n", + " else:\n", + " simplified_settings__curves[param] = curve.to_dict(simplify=True, for_yaml=True)\n", + "simplified_settings.update(simplified_settings__curves) # move verbose stuff to the bottom\n", + "\n", + "txt = OmegaConf.to_yaml(OmegaConf.create(simplified_settings))\n", "\n", "with open(outdir / 'settings.yaml', 'w') as f:\n", " f.write(txt)\n", @@ -1389,17 +1380,25 @@ "\n", "# save prompts\n", "\n", + "# verbose:\n", "prompts_out = []\n", "for prompt in prompts:\n", " rec = {'prompt':prompt.text}\n", - " prompt.weight.loop=True\n", - " rec['schedule'] = prompt.weight.to_dict(simplify=True, for_yaml=True) #)\n", - " #rec.update( prompt.weight.to_dict(simplify=True, for_yaml=True) )\n", - " rec['prompt'] = prompt.text\n", + " \n", + " # verbose:\n", + " #rec['schedule'] = prompt.weight.to_dict(simplify=True, for_yaml=True) #)\n", + " # less verbose:\n", + " curve = prompt.weight\n", + " kf0 = curve._data[0]\n", + " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", + " rec['weight'] = kf0.value\n", + " else:\n", + " rec['weight'] = curve.to_dict(simplify=True, for_yaml=True)\n", + " # don't reorder prompts. 
order matters, esp first prompt.\n", " prompts_out.append(rec)\n", "\n", "prompts_yaml = OmegaConf.to_yaml(OmegaConf.create({'prompts':prompts_out}))\n", - "\n", + "print(prompts_yaml)\n", "with open(outdir / 'prompts.yaml', 'w') as f:\n", " f.write(prompts_yaml)" ] @@ -1420,11 +1419,25 @@ "load_prompts_from_disk = True # @param {type:'boolean'}\n", "\n", "\n", + "from numbers import Number\n", + "\n", "if load_settings_from_disk:\n", " with open(outdir / 'settings.yaml', 'r') as f:\n", - " curved_settings = keyframed.serialization.from_yaml(f.read())\n", + " #curved_settings = keyframed.serialization.from_yaml(f.read())\n", + " simplified_settings_yaml = f.read()\n", + " simplified_settings_cfg = OmegaConf.to_container(OmegaConf.create(simplified_settings_yaml))\n", + "\n", + " rebuilt_settings = {}\n", + " for k,v in simplified_settings_cfg.items():\n", + " if isinstance(v, Number):\n", + " param = SmoothCurve(v)\n", + " else:\n", + " param = SmoothCurve(**v)\n", + " rebuilt_settings[k] = param\n", + " curved_settings = ParameterGroup(rebuilt_settings) \n", + " \n", "\n", - "curved_settings.to_dict(simplify=True)['parameters']\n", + "#curved_settings.to_dict(simplify=True)['parameters']\n", "#curved_settings.plot()\n", "\n", "###########################\n", @@ -1435,7 +1448,11 @@ " \n", " prompts = []\n", " for p in prompts_cfg.prompts:\n", - " weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.schedule))\n", + " if isinstance(p.weight, Number):\n", + " weight_curve = SmoothCurve(p.weight)\n", + " else:\n", + " weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.weight))\n", + " #weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.schedule))\n", " P = Prompt(text=p.prompt, weight_schedule=weight_curve)\n", " prompts.append(P)\n", " #P.weight.plot()" From caaf5f5bada62a229838dd013d2f4566a51eeea3 Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 2 Mar 2023 12:48:41 -0800 Subject: [PATCH 10/14] param weight drop 
label --- Stable_Diffusion_KLMC2_Animation.ipynb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 41557a6..1842c35 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1368,7 +1368,9 @@ " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", " simplified_settings[param] = kf0.value\n", " else:\n", - " simplified_settings__curves[param] = curve.to_dict(simplify=True, for_yaml=True)\n", + " d_ = curve.to_dict(simplify=True, for_yaml=True)\n", + " d_.pop('label')\n", + " simplified_settings__curves[param] = d_\n", "simplified_settings.update(simplified_settings__curves) # move verbose stuff to the bottom\n", "\n", "txt = OmegaConf.to_yaml(OmegaConf.create(simplified_settings))\n", @@ -1398,7 +1400,7 @@ " prompts_out.append(rec)\n", "\n", "prompts_yaml = OmegaConf.to_yaml(OmegaConf.create({'prompts':prompts_out}))\n", - "print(prompts_yaml)\n", + "\n", "with open(outdir / 'prompts.yaml', 'w') as f:\n", " f.write(prompts_yaml)" ] From 54d40b2c552beb8c7a33a6b1a1981a82a86b1362 Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 2 Mar 2023 22:10:02 -0800 Subject: [PATCH 11/14] simpler settings load w keyframed==0.3.9 --- Stable_Diffusion_KLMC2_Animation.ipynb | 32 +++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 1842c35..01d1a7f 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1373,7 +1373,7 @@ " simplified_settings__curves[param] = d_\n", "simplified_settings.update(simplified_settings__curves) # move verbose stuff to the bottom\n", "\n", - "txt = OmegaConf.to_yaml(OmegaConf.create(simplified_settings))\n", + "txt = OmegaConf.to_yaml(OmegaConf.create({'parameters':simplified_settings}))\n", 
"\n", "with open(outdir / 'settings.yaml', 'w') as f:\n", " f.write(txt)\n", @@ -1386,7 +1386,7 @@ "prompts_out = []\n", "for prompt in prompts:\n", " rec = {'prompt':prompt.text}\n", - " \n", + "\n", " # verbose:\n", " #rec['schedule'] = prompt.weight.to_dict(simplify=True, for_yaml=True) #)\n", " # less verbose:\n", @@ -1425,19 +1425,19 @@ "\n", "if load_settings_from_disk:\n", " with open(outdir / 'settings.yaml', 'r') as f:\n", - " #curved_settings = keyframed.serialization.from_yaml(f.read())\n", - " simplified_settings_yaml = f.read()\n", - " simplified_settings_cfg = OmegaConf.to_container(OmegaConf.create(simplified_settings_yaml))\n", - "\n", - " rebuilt_settings = {}\n", - " for k,v in simplified_settings_cfg.items():\n", - " if isinstance(v, Number):\n", - " param = SmoothCurve(v)\n", - " else:\n", - " param = SmoothCurve(**v)\n", - " rebuilt_settings[k] = param\n", - " curved_settings = ParameterGroup(rebuilt_settings) \n", - " \n", + " curved_settings = keyframed.serialization.from_yaml(f.read())\n", + "# simplified_settings_yaml = f.read()\n", + "# simplified_settings_cfg = OmegaConf.to_container(OmegaConf.create(simplified_settings_yaml))\n", + "\n", + "# rebuilt_settings = {}\n", + "# for k,v in simplified_settings_cfg.items():\n", + "# if isinstance(v, Number):\n", + "# param = SmoothCurve(v)\n", + "# else:\n", + "# param = SmoothCurve(**v)\n", + "# rebuilt_settings[k] = param\n", + "# curved_settings = ParameterGroup(rebuilt_settings)\n", + "\n", "\n", "#curved_settings.to_dict(simplify=True)['parameters']\n", "#curved_settings.plot()\n", @@ -1447,7 +1447,7 @@ "if load_prompts_from_disk:\n", " with open(outdir / 'prompts.yaml', 'r') as f:\n", " prompts_cfg = OmegaConf.load(f)\n", - " \n", + "\n", " prompts = []\n", " for p in prompts_cfg.prompts:\n", " if isinstance(p.weight, Number):\n", From f76c6290c7c522ff0a7f77f35a88657700bb5948 Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 2 Mar 2023 22:16:54 -0800 Subject: [PATCH 12/14] cleanup --- 
Stable_Diffusion_KLMC2_Animation.ipynb | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 01d1a7f..660be01 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1377,19 +1377,14 @@ "\n", "with open(outdir / 'settings.yaml', 'w') as f:\n", " f.write(txt)\n", - " \n", + "\n", "#########################\n", "\n", "# save prompts\n", "\n", - "# verbose:\n", "prompts_out = []\n", "for prompt in prompts:\n", " rec = {'prompt':prompt.text}\n", - "\n", - " # verbose:\n", - " #rec['schedule'] = prompt.weight.to_dict(simplify=True, for_yaml=True) #)\n", - " # less verbose:\n", " curve = prompt.weight\n", " kf0 = curve._data[0]\n", " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", @@ -1426,20 +1421,7 @@ "if load_settings_from_disk:\n", " with open(outdir / 'settings.yaml', 'r') as f:\n", " curved_settings = keyframed.serialization.from_yaml(f.read())\n", - "# simplified_settings_yaml = f.read()\n", - "# simplified_settings_cfg = OmegaConf.to_container(OmegaConf.create(simplified_settings_yaml))\n", - "\n", - "# rebuilt_settings = {}\n", - "# for k,v in simplified_settings_cfg.items():\n", - "# if isinstance(v, Number):\n", - "# param = SmoothCurve(v)\n", - "# else:\n", - "# param = SmoothCurve(**v)\n", - "# rebuilt_settings[k] = param\n", - "# curved_settings = ParameterGroup(rebuilt_settings)\n", - "\n", "\n", - "#curved_settings.to_dict(simplify=True)['parameters']\n", "#curved_settings.plot()\n", "\n", "###########################\n", @@ -1454,7 +1436,6 @@ " weight_curve = SmoothCurve(p.weight)\n", " else:\n", " weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.weight))\n", - " #weight_curve = keyframed.serialization.from_yaml(OmegaConf.to_yaml(p.schedule))\n", " P = Prompt(text=p.prompt, weight_schedule=weight_curve)\n", " 
prompts.append(P)\n", " #P.weight.plot()" From a895f076ff2f6bde630b069a88dbd05b21499b62 Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 2 Mar 2023 22:47:57 -0800 Subject: [PATCH 13/14] moved settings write to not conflict w /frames flush --- Stable_Diffusion_KLMC2_Animation.ipynb | 112 ++++++++++++++----------- 1 file changed, 65 insertions(+), 47 deletions(-) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 660be01..019bcf0 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -1353,52 +1353,7 @@ "tags": [] }, "outputs": [], - "source": [ - "# @markdown running this cell saves the current settings to disk\n", - "\n", - "import keyframed.serialization\n", - "# verbose:\n", - "#txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", - "\n", - "# significantly less verbose:\n", - "simplified_settings = {}\n", - "simplified_settings__curves = {}\n", - "for param, curve in curved_settings.parameters.items():\n", - " kf0 = curve._data[0]\n", - " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", - " simplified_settings[param] = kf0.value\n", - " else:\n", - " d_ = curve.to_dict(simplify=True, for_yaml=True)\n", - " d_.pop('label')\n", - " simplified_settings__curves[param] = d_\n", - "simplified_settings.update(simplified_settings__curves) # move verbose stuff to the bottom\n", - "\n", - "txt = OmegaConf.to_yaml(OmegaConf.create({'parameters':simplified_settings}))\n", - "\n", - "with open(outdir / 'settings.yaml', 'w') as f:\n", - " f.write(txt)\n", - "\n", - "#########################\n", - "\n", - "# save prompts\n", - "\n", - "prompts_out = []\n", - "for prompt in prompts:\n", - " rec = {'prompt':prompt.text}\n", - " curve = prompt.weight\n", - " kf0 = curve._data[0]\n", - " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", - " 
rec['weight'] = kf0.value\n", - " else:\n", - " rec['weight'] = curve.to_dict(simplify=True, for_yaml=True)\n", - " # don't reorder prompts. order matters, esp first prompt.\n", - " prompts_out.append(rec)\n", - "\n", - "prompts_yaml = OmegaConf.to_yaml(OmegaConf.create({'prompts':prompts_out}))\n", - "\n", - "with open(outdir / 'prompts.yaml', 'w') as f:\n", - " f.write(prompts_yaml)" - ] + "source": [] }, { "cell_type": "code", @@ -1508,6 +1463,57 @@ " for p in Path('debug_frames').glob(f'*'):\n", " p.unlink()\n", "\n", + "\n", + "#############################################\n", + "\n", + "# save settings\n", + "\n", + "import keyframed.serialization\n", + "# verbose:\n", + "#txt = keyframed.serialization.to_yaml(curved_settings, simplify=True)\n", + "\n", + "# significantly less verbose:\n", + "simplified_settings = {}\n", + "simplified_settings__curves = {}\n", + "for param, curve in curved_settings.parameters.items():\n", + " kf0 = curve._data[0]\n", + " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", + " simplified_settings[param] = kf0.value\n", + " else:\n", + " d_ = curve.to_dict(simplify=True, for_yaml=True)\n", + " d_.pop('label')\n", + " simplified_settings__curves[param] = d_\n", + "simplified_settings.update(simplified_settings__curves) # move verbose stuff to the bottom\n", + "\n", + "txt = OmegaConf.to_yaml(OmegaConf.create({'parameters':simplified_settings}))\n", + "\n", + "with open(outdir / 'settings.yaml', 'w') as f:\n", + " f.write(txt)\n", + "\n", + "#########################\n", + "\n", + "# save prompts\n", + "\n", + "prompts_out = []\n", + "for prompt in prompts:\n", + " rec = {'prompt':prompt.text}\n", + " curve = prompt.weight\n", + " kf0 = curve._data[0]\n", + " if (len(curve.keyframes) == 1) and (kf0.interpolation_method in ['eased_lerp', 'previous','linear']):\n", + " rec['weight'] = kf0.value\n", + " else:\n", + " rec['weight'] = curve.to_dict(simplify=True, for_yaml=True)\n", 
+ " # don't reorder prompts. order matters, esp first prompt.\n", + " prompts_out.append(rec)\n", + "\n", + "prompts_yaml = OmegaConf.to_yaml(OmegaConf.create({'prompts':prompts_out}))\n", + "\n", + "with open(outdir / 'prompts.yaml', 'w') as f:\n", + " f.write(prompts_yaml)\n", + "\n", + "##########################\n", + "\n", + "\n", "sample_mcmc_klmc2(\n", " sd_model=sd_model,\n", " init_image=init_image,\n", @@ -1525,12 +1531,24 @@ ")\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!rm -rf frames/.ipynb_checkpoints" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", - "id": "DjwY7XrooLX_" + "id": "DjwY7XrooLX_", + "tags": [] }, "outputs": [], "source": [ From ae344a123a4b5df89189e4af1ec7c0da5e4486fc Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 2 Mar 2023 22:59:00 -0800 Subject: [PATCH 14/14] fixed resume --- Stable_Diffusion_KLMC2_Animation.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/Stable_Diffusion_KLMC2_Animation.ipynb b/Stable_Diffusion_KLMC2_Animation.ipynb index 019bcf0..8084616 100644 --- a/Stable_Diffusion_KLMC2_Animation.ipynb +++ b/Stable_Diffusion_KLMC2_Animation.ipynb @@ -636,6 +636,7 @@ " x, v, i_resume = state['x'], state['v'], state['i']\n", " # to do: resumption of settings\n", " settings_i = state['settings_i']\n", + " i = i_resume\n", " settings[i]['h'] = settings_i['h']\n", " settings[i]['gamma'] = settings_i['gamma']\n", " settings[i]['alpha'] = settings_i['alpha']\n",