{ "cells": [ { "cell_type": "markdown", "id": "530feebe", "metadata": {}, "source": [ "`PotentialODE.reference.ipynb` \n", "---" ] }, { "cell_type": "code", "execution_count": 1, "id": "bd0a3fa6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Version: 0.3.2\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 1;\n", " var nbb_unformatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_formatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%load_ext nb_black\n", "# pip install neural-diffeqs\n", "\n", "import neural_diffeqs\n", "\n", "print(f\"Version: {neural_diffeqs.__version__}\")\n", "import torch" ] }, { "cell_type": "markdown", "id": "763caa74", "metadata": {}, "source": [ "### Default `PotentialODE`\n", "\n", "As in the `NeuralODE` and `NeuralSDE`, the only required parameter is:\n", "\n", "* `state_size`" ] }, { "cell_type": "code", "execution_count": 2, "id": "5b650fd1", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PotentialODE(\n", " (mu): TorchNet(\n", " (hidden_1): Sequential(\n", " (linear): Linear(in_features=20, out_features=64, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (hidden_2): Sequential(\n", " (linear): 
Linear(in_features=64, out_features=64, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (output): Sequential(\n", " (linear): Linear(in_features=64, out_features=1, bias=False)\n", " )\n", " )\n", ")\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 2;\n", " var nbb_unformatted_code = \"ODE = neural_diffeqs.PotentialODE(state_size=20, mu_hidden=[64, 64])\\nprint(ODE)\";\n", " var nbb_formatted_code = \"ODE = neural_diffeqs.PotentialODE(state_size=20, mu_hidden=[64, 64])\\nprint(ODE)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ODE = neural_diffeqs.PotentialODE(state_size=20, mu_hidden=[64, 64])\n", "print(ODE)" ] }, { "cell_type": "markdown", "id": "62b0ead9", "metadata": {}, "source": [ "Notice that the output layer of the `mu` function (drift network) has only a single output feature and, by default, no bias:\n", "\n", ">```\n", "(output): Sequential(\n", " (linear): Linear(in_features=64, out_features=1, bias=False)\n", ")\n", "```" ] }, { "cell_type": "code", "execution_count": 3, "id": "4ddc9ae6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.3228, 0.1208, -0.4852, -1.5084, -1.5721, -0.3639, 0.7440, -0.3686,\n", " 0.0311, 0.0771, -0.2299, 0.6744, 1.0946, 1.8330, -0.0296, 0.0648,\n", " 1.1006, -0.2764, 1.1437, 2.3773],\n", " [-0.3958, 0.5138, 0.0432, -0.0932, 2.2951, -1.3561, 0.4870, -1.4136,\n", " 1.0158, -1.5997, -2.5283, 0.5416, -0.0307, -1.3237, -0.1767, 1.9054,\n", " 1.5694, -0.5196, 0.3701, 1.1737],\n", " [ 0.4174, 0.5138, -1.1092, -0.0051, -1.0940, 
-0.1167, 1.0324, -1.6913,\n", " -0.0920, -0.4754, 0.2900, -0.9673, -2.0961, -0.1045, 1.0242, -0.9995,\n", " 1.1757, -0.4290, 1.8097, 0.9739],\n", " [ 1.1494, 0.0153, 0.2669, -0.2239, -0.3777, 0.6096, 0.1383, -0.6071,\n", " -0.9103, 0.7419, -0.2132, -0.0099, -1.1543, -0.0944, 1.1588, 0.7124,\n", " -2.2216, -0.0356, 0.0506, 0.1203],\n", " [ 0.9027, -0.2452, 0.3295, 1.3359, 0.5843, 0.5813, 0.2054, 0.1454,\n", " 1.4566, 0.4892, -0.5222, -0.6388, 0.4159, 0.8533, -0.8192, -1.3142,\n", " 1.6547, 0.9800, 0.4647, -0.4676]])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 3;\n", " var nbb_unformatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_formatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 5 samples x 20 dim\n", "y = torch.randn([5, 20])\n", "print(y)" ] }, { "cell_type": "markdown", "id": "ff7d51e3", "metadata": {}, "source": [ "`PotentialODE.f` and `PotentialODE.drift` are identical. Under the hood, the following occurs:\n", "\n", "```python\n", "def drift(self, y: torch.Tensor) -> torch.Tensor:\n", " \n", " y = y.requires_grad_()\n", " ψ = self._potential(y)\n", " return self._gradient(ψ, y) * self._coef_drift\n", "```\n", "\n", "Wherein `PotentialODE._potential(y)` computes `ψ = PotentialODE.mu(y)`. `ψ` is of shape: `[y.shape[0], 1]` or `[n_samples x 1]`. 
While the regular `NeuralODE` computes its drift directly as `net(y)`, a `PotentialODE` takes one further step in `PotentialODE._gradient(ψ, y)`, which differentiates `ψ` with respect to `y`:\n", "\n", "```python\n", "y_hat = torch.autograd.grad(ψ, y, torch.ones_like(ψ), create_graph=True)[0]\n", "```\n", "\n", "We'll skip `self._coef_drift` for now." ] }, { "cell_type": "code", "execution_count": 4, "id": "6f4818b4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.0191, -0.0182, 0.0078, 0.0020, 0.0070, -0.0190, -0.0366, -0.0121,\n", " -0.0267, -0.0194, 0.0034, -0.0013, -0.0185, -0.0411, -0.0280, 0.0304,\n", " -0.0010, -0.0459, -0.0191, -0.0101],\n", " [-0.0118, -0.0113, 0.0267, -0.0601, -0.0217, -0.0136, -0.0339, 0.0420,\n", " 0.0257, 0.0173, -0.0147, 0.0158, -0.0204, 0.0089, -0.0070, 0.0195,\n", " -0.0142, -0.0115, 0.0208, -0.0027],\n", " [ 0.0078, -0.0325, 0.0557, -0.0108, 0.0017, -0.0035, -0.0211, -0.0244,\n", " -0.0189, -0.0022, -0.0195, -0.0373, 0.0397, -0.0341, -0.0472, 0.0457,\n", " -0.0276, -0.0150, 0.0251, 0.0113],\n", " [ 0.0117, -0.0229, 0.0043, 0.0080, -0.0022, -0.0083, -0.0038, -0.0119,\n", " -0.0275, -0.0035, 0.0126, 0.0075, 0.0523, 0.0035, -0.0165, 0.0104,\n", " 0.0084, 0.0045, 0.0319, -0.0330],\n", " [-0.0044, -0.0056, 0.0090, 0.0182, 0.0087, 0.0157, -0.0245, 0.0270,\n", " -0.0402, -0.0136, -0.0249, 0.0182, -0.0473, 0.0016, 0.0184, 0.0178,\n", " -0.0443, -0.0500, -0.0312, 0.0028]], grad_fn=)\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 4;\n", " var nbb_unformatted_code = \"y_f_hat = ODE.f(t=None, y=y)\\nprint(y_f_hat)\";\n", " var nbb_formatted_code = \"y_f_hat = ODE.f(t=None, y=y)\\nprint(y_f_hat)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " 
break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_f_hat = ODE.f(t=None, y=y)\n", "print(y_f_hat)" ] }, { "cell_type": "code", "execution_count": 5, "id": "af3e486c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.0191, -0.0182, 0.0078, 0.0020, 0.0070, -0.0190, -0.0366, -0.0121,\n", " -0.0267, -0.0194, 0.0034, -0.0013, -0.0185, -0.0411, -0.0280, 0.0304,\n", " -0.0010, -0.0459, -0.0191, -0.0101],\n", " [-0.0118, -0.0113, 0.0267, -0.0601, -0.0217, -0.0136, -0.0339, 0.0420,\n", " 0.0257, 0.0173, -0.0147, 0.0158, -0.0204, 0.0089, -0.0070, 0.0195,\n", " -0.0142, -0.0115, 0.0208, -0.0027],\n", " [ 0.0078, -0.0325, 0.0557, -0.0108, 0.0017, -0.0035, -0.0211, -0.0244,\n", " -0.0189, -0.0022, -0.0195, -0.0373, 0.0397, -0.0341, -0.0472, 0.0457,\n", " -0.0276, -0.0150, 0.0251, 0.0113],\n", " [ 0.0117, -0.0229, 0.0043, 0.0080, -0.0022, -0.0083, -0.0038, -0.0119,\n", " -0.0275, -0.0035, 0.0126, 0.0075, 0.0523, 0.0035, -0.0165, 0.0104,\n", " 0.0084, 0.0045, 0.0319, -0.0330],\n", " [-0.0044, -0.0056, 0.0090, 0.0182, 0.0087, 0.0157, -0.0245, 0.0270,\n", " -0.0402, -0.0136, -0.0249, 0.0182, -0.0473, 0.0016, 0.0184, 0.0178,\n", " -0.0443, -0.0500, -0.0312, 0.0028]], grad_fn=)\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 5;\n", " var nbb_unformatted_code = \"y_f_hat = ODE.drift(y=y)\\nprint(y_f_hat)\";\n", " var nbb_formatted_code = \"y_f_hat = ODE.drift(y=y)\\nprint(y_f_hat)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], 
"source": [ "y_f_hat = ODE.drift(y=y)\n", "print(y_f_hat)" ] }, { "cell_type": "markdown", "id": "1019d7e9", "metadata": {}, "source": [ "Compute the potential, `ψ`, of the observed state, `y`, i.e., `ψ(y)`:" ] }, { "cell_type": "code", "execution_count": 6, "id": "83493b53", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.2161],\n", " [-0.1079],\n", " [-0.2329],\n", " [-0.0970],\n", " [-0.1882]], grad_fn=)\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 6;\n", " var nbb_unformatted_code = \"y_hat_potential = ODE._potential(y=y)\\nprint(y_hat_potential)\";\n", " var nbb_formatted_code = \"y_hat_potential = ODE._potential(y=y)\\nprint(y_hat_potential)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_hat_potential = ODE._potential(y=y)\n", "print(y_hat_potential)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:sdq]", "language": "python", "name": "conda-env-sdq-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }