{ "cells": [ { "cell_type": "markdown", "id": "530feebe", "metadata": {}, "source": [ "`LatentPotentialSDE.reference.ipynb` \n", "---" ] }, { "cell_type": "code", "execution_count": 1, "id": "bd0a3fa6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Version: 0.3.2\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 1;\n", " var nbb_unformatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_formatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%load_ext nb_black\n", "# pip install neural-diffeqs\n", "\n", "import neural_diffeqs\n", "\n", "print(f\"Version: {neural_diffeqs.__version__}\")\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "id": "5b650fd1", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LatentPotentialSDE(\n", " (mu): TorchNet(\n", " (hidden_1): Sequential(\n", " (linear): Linear(in_features=20, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (hidden_2): Sequential(\n", " (linear): Linear(in_features=512, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " 
)\n", " (output): Sequential(\n", " (linear): Linear(in_features=512, out_features=20, bias=False)\n", " )\n", " )\n", " (sigma): TorchNet(\n", " (hidden_1): Sequential(\n", " (linear): Linear(in_features=20, out_features=32, bias=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (hidden_2): Sequential(\n", " (linear): Linear(in_features=32, out_features=32, bias=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (output): Sequential(\n", " (linear): Linear(in_features=32, out_features=20, bias=True)\n", " )\n", " )\n", " (potential): Potential(\n", " (psi): Linear(in_features=20, out_features=1, bias=False)\n", " )\n", ")\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 2;\n", " var nbb_unformatted_code = \"SDE = neural_diffeqs.LatentPotentialSDE(\\n state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]\\n)\\nprint(SDE)\";\n", " var nbb_formatted_code = \"SDE = neural_diffeqs.LatentPotentialSDE(\\n state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]\\n)\\nprint(SDE)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "SDE = neural_diffeqs.LatentPotentialSDE(\n", " state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]\n", ")\n", "print(SDE)" ] }, { "cell_type": "markdown", "id": "62b0ead9", "metadata": {}, "source": [ "Notice that the output layer of the mu function (drift network) has no bias (by default) and keeps the full 20-dimensional state; the single-feature, bias-free layer is the `psi` layer of the separate `potential` module:\n", "\n", ">```\n", ">(output): Sequential(\n", "> 
(linear): Linear(in_features=512, out_features=20, bias=False)\n", ">)\n", ">...\n", ">(potential): Potential(\n", "> (psi): Linear(in_features=20, out_features=1, bias=False)\n", ">)\n", ">```\n", "\n", "This is, so far, identical to the `neural_diffeqs.PotentialSDE`. The key difference is the introduction of the `h()` function." ] }, { "cell_type": "markdown", "id": "b1ca78ca", "metadata": {}, "source": [ "### `LatentPotentialSDE.h(y)`" ] }, { "cell_type": "code", "execution_count": 3, "id": "4ddc9ae6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.1489, -0.8965, -0.4265, -0.0423, -0.2972, -0.6784, 0.2970, -0.1280,\n", " 1.0155, -0.0928, 2.4592, 0.4539, -0.6524, 0.0763, 1.9708, 0.7134,\n", " 0.2520, 0.9877, 0.0877, 1.1103],\n", " [-0.2822, 0.6845, -0.1269, -2.4295, 1.7630, 0.4437, -1.7529, -1.2660,\n", " 0.1211, -0.1806, -0.5824, 0.6802, -0.9001, 1.1840, 0.6217, -0.6073,\n", " -0.6521, -0.1147, -0.5269, 0.8745],\n", " [-1.2511, 1.0869, -1.1027, -0.6977, 0.3685, 0.2092, -0.2874, -0.3324,\n", " -1.3541, -0.9840, -0.8346, -0.4875, -1.7272, 0.7593, 1.0221, 0.6649,\n", " -0.0999, -0.2326, -0.1195, 0.8868],\n", " [-1.0061, -0.0645, -0.6290, 0.4299, -0.7686, 1.6165, 0.2057, -0.4136,\n", " -1.9558, 0.4749, -1.5010, 0.5263, 0.0259, -0.3909, 0.3455, 0.9469,\n", " -0.9320, 1.4909, -1.5104, 1.1154],\n", " [ 0.4129, 0.0150, 0.3230, -0.0096, -0.4123, -1.0741, -0.3676, 0.2829,\n", " -0.4771, -0.5568, 1.7971, -0.5946, 0.3360, -0.7315, 0.5277, -0.8717,\n", " 0.9425, -0.1722, -0.1044, -0.2779]])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 3;\n", " var nbb_unformatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_formatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " 
nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 5 samples x 20 dim\n", "y = torch.randn([5, 20])\n", "print(y)" ] }, { "cell_type": "code", "execution_count": 4, "id": "0485b578", "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 4;\n", " var nbb_unformatted_code = \"# output of a function without assumption\\nf_out = SDE.f(None, y)\\n\\n# output of a prior-regularized function\\nh_out = SDE.h(None, y)\";\n", " var nbb_formatted_code = \"# output of a function without assumption\\nf_out = SDE.f(None, y)\\n\\n# output of a prior-regularized function\\nh_out = SDE.h(None, y)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# output of a function without assumption\n", "f_out = SDE.f(None, y)\n", "\n", "# output of a prior-regularized function\n", "h_out = SDE.h(None, y)" ] }, { "cell_type": "code", "execution_count": 5, "id": "c8675963", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "f_out shape: torch.Size([5, 20]), h_out shape: torch.Size([5, 20])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 5;\n", " var nbb_unformatted_code = \"print(f\\\"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}\\\")\";\n", " var nbb_formatted_code = \"print(f\\\"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}\\\")\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; 
i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(f\"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}\")" ] }, { "cell_type": "markdown", "id": "d62ff0fe", "metadata": {}, "source": [ "`torchsde.sdeint` has built-in functionality to compute the KL-divergence along the predicted trajectory; when called with `logqp=True`, it is returned alongside the predicted states, with one value per sample for each interval between consecutive time points (here 3 time points give 2 intervals):" ] }, { "cell_type": "code", "execution_count": 6, "id": "009b3609", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 5, 20]) torch.Size([2, 5])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 6;\n", " var nbb_unformatted_code = \"import torchsde\\n\\npred, kl_div = torchsde.sdeint(SDE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)\\nprint(pred.shape, kl_div.shape)\";\n", " var nbb_formatted_code = \"import torchsde\\n\\npred, kl_div = torchsde.sdeint(SDE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)\\nprint(pred.shape, kl_div.shape)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import torchsde\n", "\n", "pred, kl_div = torchsde.sdeint(SDE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)\n", "print(pred.shape, kl_div.shape)" ] }, { "cell_type": "markdown", "id": "3b5c4f03", "metadata": {}, "source": [ "For more about the `PotentialSDE` and how the potential 
function works, see the `PotentialSDE` notebook." ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:sdq]", "language": "python", "name": "conda-env-sdq-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }