{ "cells": [ { "cell_type": "markdown", "id": "530feebe", "metadata": {}, "source": [ "`LatentPotentialODE.reference.ipynb` \n", "---" ] }, { "cell_type": "code", "execution_count": 1, "id": "bd0a3fa6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Version: 0.3.2\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 1;\n", " var nbb_unformatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_formatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# %load_ext nb_black\n", "# pip install neural-diffeqs\n", "\n", "import neural_diffeqs\n", "\n", "print(f\"Version: {neural_diffeqs.__version__}\")\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "id": "5b650fd1", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LatentPotentialODE(\n", " (mu): TorchNet(\n", " (hidden_1): Sequential(\n", " (linear): Linear(in_features=20, out_features=32, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (hidden_2): Sequential(\n", " (linear): Linear(in_features=32, out_features=32, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (output): Sequential(\n", " (linear): Linear(in_features=32, out_features=20, 
bias=True)\n", " )\n", " )\n", " (potential): Potential(\n", " (psi): Linear(in_features=20, out_features=1, bias=False)\n", " )\n", ")\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 2;\n", " var nbb_unformatted_code = \"ODE = neural_diffeqs.LatentPotentialODE(state_size=20, mu_hidden=[32, 32])\\nprint(ODE)\";\n", " var nbb_formatted_code = \"ODE = neural_diffeqs.LatentPotentialODE(state_size=20, mu_hidden=[32, 32])\\nprint(ODE)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ODE = neural_diffeqs.LatentPotentialODE(state_size=20, mu_hidden=[32, 32])\n", "print(ODE)" ] }, { "cell_type": "markdown", "id": "62b0ead9", "metadata": {}, "source": [ "Notice that, in addition to the mu function (drift network), the model contains a `potential` module whose `psi` layer has only a single output feature, without bias (by default):\n", "\n", ">```\n", ">(potential): Potential(\n", "> (psi): Linear(in_features=20, out_features=1, bias=False)\n", ">)\n", ">```\n", "\n", "This is, so far, identical to the `neural_diffeqs.PotentialSDE`. The key difference is the introduction of the `h()` function." 
] }, { "cell_type": "markdown", "id": "b1ca78ca", "metadata": {}, "source": [ "### `LatentPotentialODE.h(y)`" ] }, { "cell_type": "code", "execution_count": 3, "id": "4ddc9ae6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.8237, 1.0351, 1.7423, -1.5425, 0.4430, 0.8470, -1.4426, 0.4999,\n", " 1.5510, -0.2683, -1.7453, 0.9307, 1.3925, 1.5210, 1.6499, 0.7581,\n", " -0.7673, -0.3574, 0.6198, -1.3364],\n", " [-0.8484, -0.0280, -0.9399, -1.2865, 0.0272, 1.3988, -0.8500, 0.0987,\n", " -0.0119, 0.1664, 1.4062, -0.9316, -1.2265, -0.4161, 0.3649, 0.1787,\n", " 0.4278, 0.7464, -1.7572, -0.1976],\n", " [-0.5799, 0.2325, 0.0680, 0.8387, -0.1876, -2.5092, 0.0367, -0.0986,\n", " -0.7004, -1.6255, -0.3075, -0.4883, -0.2802, 0.5861, -0.1205, 0.2164,\n", " -0.9112, 0.1279, 1.9563, 0.5955],\n", " [-0.5629, 0.3056, -0.8121, -1.9160, -0.9585, 0.1542, -0.7999, 0.2500,\n", " -0.8330, -1.6756, -0.3278, 2.4593, -0.6612, -0.9197, 0.0237, -1.4666,\n", " 0.9548, 0.7736, 0.8984, -1.3243],\n", " [-0.0242, -1.0799, -0.3576, -0.3413, -0.6483, -1.3423, -0.9658, 0.8313,\n", " 0.3130, -0.5277, -1.0104, 0.3941, 0.2500, -0.1888, 0.3429, -0.9617,\n", " -1.5048, -0.4335, -0.3003, 0.8630]])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 3;\n", " var nbb_unformatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_formatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 5 samples x 20 dim\n", "y = torch.randn([5, 
20])\n", "print(y)" ] }, { "cell_type": "code", "execution_count": 4, "id": "0485b578", "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 4;\n", " var nbb_unformatted_code = \"# output of a function without assumption\\nf_out = ODE.f(None, y)\\n\\n# output of a prior-regularized function\\nh_out = ODE.h(None, y)\";\n", " var nbb_formatted_code = \"# output of a function without assumption\\nf_out = ODE.f(None, y)\\n\\n# output of a prior-regularized function\\nh_out = ODE.h(None, y)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# output of a function without assumption\n", "f_out = ODE.f(None, y)\n", "\n", "# output of a prior-regularized function\n", "h_out = ODE.h(None, y)" ] }, { "cell_type": "code", "execution_count": 5, "id": "c8675963", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "f_out shape: torch.Size([5, 20]), h_out shape: torch.Size([5, 20])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 5;\n", " var nbb_unformatted_code = \"print(f\\\"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}\\\")\";\n", " var nbb_formatted_code = \"print(f\\\"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}\\\")\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", 
" }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(f\"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}\")" ] }, { "cell_type": "markdown", "id": "d62ff0fe", "metadata": {}, "source": [ "`torchsde.sdeint` has built-in functionality to compute the KL-divergence along the predicted trajectory. When `logqp=True` is passed, it is returned alongside the predicted states (here, one value per interval between consecutive time points, for each sample):" ] }, { "cell_type": "code", "execution_count": 6, "id": "009b3609", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 5, 20]) torch.Size([2, 5])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 6;\n", " var nbb_unformatted_code = \"import torchsde\\n\\npred, kl_div = torchsde.sdeint(ODE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)\\nprint(pred.shape, kl_div.shape)\";\n", " var nbb_formatted_code = \"import torchsde\\n\\npred, kl_div = torchsde.sdeint(ODE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)\\nprint(pred.shape, kl_div.shape)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import torchsde\n", "\n", "pred, kl_div = torchsde.sdeint(ODE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)\n", "print(pred.shape, kl_div.shape)" ] }, { "cell_type": "markdown", "id": "3b5c4f03", "metadata": {}, "source": [ "For more about the `PotentialSDE` and how the potential function works, see the `PotentialSDE` notebook." 
] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:sdq]", "language": "python", "name": "conda-env-sdq-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }