{ "cells": [ { "cell_type": "markdown", "id": "530feebe", "metadata": {}, "source": [ "`PotentialSDE.reference.ipynb` \n", "---" ] }, { "cell_type": "code", "execution_count": 1, "id": "bd0a3fa6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Version: 0.3.2\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 1;\n", " var nbb_unformatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_formatted_code = \"%load_ext nb_black\\n# pip install neural-diffeqs\\n\\nimport neural_diffeqs\\n\\nprint(f\\\"Version: {neural_diffeqs.__version__}\\\")\\nimport torch\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# %load_ext nb_black\n", "# pip install neural-diffeqs\n", "\n", "import neural_diffeqs\n", "\n", "print(f\"Version: {neural_diffeqs.__version__}\")\n", "import torch" ] }, { "cell_type": "markdown", "id": "763caa74", "metadata": {}, "source": [ "### Default `PotentialSDE`\n", "\n", "As in the `NeuralSDE`, the only required parameter is:\n", "\n", "* `state_size`" ] }, { "cell_type": "code", "execution_count": 23, "id": "5b650fd1", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PotentialSDE(\n", " (mu): TorchNet(\n", " (hidden_1): Sequential(\n", " (linear): Linear(in_features=20, out_features=512, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (hidden_2): Sequential(\n", " (linear): 
Linear(in_features=512, out_features=512, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (output): Sequential(\n", " (linear): Linear(in_features=512, out_features=1, bias=False)\n", " )\n", " )\n", " (sigma): TorchNet(\n", " (hidden_1): Sequential(\n", " (linear): Linear(in_features=20, out_features=32, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (hidden_2): Sequential(\n", " (linear): Linear(in_features=32, out_features=32, bias=True)\n", " (activation): LeakyReLU(negative_slope=0.01)\n", " )\n", " (output): Sequential(\n", " (linear): Linear(in_features=32, out_features=20, bias=True)\n", " )\n", " )\n", ")\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 23;\n", " var nbb_unformatted_code = \"SDE = neural_diffeqs.PotentialSDE(\\n state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]\\n)\\nprint(SDE)\";\n", " var nbb_formatted_code = \"SDE = neural_diffeqs.PotentialSDE(\\n state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]\\n)\\nprint(SDE)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "SDE = neural_diffeqs.PotentialSDE(\n", " state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]\n", ")\n", "print(SDE)" ] }, { "cell_type": "markdown", "id": "62b0ead9", "metadata": {}, "source": [ "Notice that the output layer of the `mu` function (drift network) has only a single output feature and, by default, no bias:\n", "\n", "```\n", "(output): Sequential(\n", " (linear): Linear(in_features=512, out_features=1, bias=False)\n", ")\n", "```" ] }, { "cell_type": "code", 
"execution_count": 24, "id": "4ddc9ae6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.5416, -0.2021, 0.3005, -0.5479, -0.4049, 0.0587, 0.0683, 0.2568,\n", " -0.2195, 1.7619, 1.0485, -0.3750, 1.2991, 0.3992, -0.8216, 1.2974,\n", " 0.5397, -2.5426, 1.5451, 0.8640],\n", " [ 0.7749, -0.9696, 0.0859, 0.0268, -0.3132, 0.8551, -0.4382, 1.9332,\n", " 0.7375, 0.9632, 0.2070, -0.1819, 0.2230, 0.0336, -0.8258, 0.4693,\n", " -0.7501, 0.4341, -0.7673, -0.4404],\n", " [-0.2253, 1.8405, -1.6210, 0.6966, -0.0804, -0.6984, 0.1994, -0.7314,\n", " 0.2867, -1.4236, 0.5274, 0.6056, 0.8109, 1.6812, -0.2225, -0.0951,\n", " -0.2295, 1.6923, 0.1866, -1.3322],\n", " [-0.6949, 1.0397, 0.1210, -0.4625, -0.2933, 1.4252, 1.1917, 0.6404,\n", " -0.2890, 0.2028, -1.3005, 1.2904, -1.4959, -0.2001, -0.6639, -1.1602,\n", " -0.9119, 1.0591, 0.4824, 0.4436],\n", " [-0.6839, 1.5822, -0.2486, 1.3021, 1.4632, 0.6358, 0.1735, 0.5642,\n", " -0.8587, -0.5523, -0.9715, 0.8956, -0.0801, -1.1502, 0.3116, 1.1249,\n", " 2.1384, 1.0490, -0.0455, -0.2494]])\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 24;\n", " var nbb_unformatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_formatted_code = \"# 5 samples x 20 dim\\ny = torch.randn([5, 20])\\nprint(y)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 5 samples x 20 dim\n", "y = torch.randn([5, 20])\n", "print(y)" ] }, { "cell_type": "markdown", "id": "ff7d51e3", "metadata": {}, "source": [ "`PotentialSDE.f` and `PotentialSDE.drift` are 
identical. Under the hood, the following occurs:\n", "\n", "```python\n", "def drift(self, y: torch.Tensor) -> torch.Tensor:\n", " \n", " y = y.requires_grad_()\n", " ψ = self._potential(y)\n", " return self._gradient(ψ, y) * self._coef_drift\n", "```\n", "\n", "Here, `PotentialSDE._potential(y)` computes `ψ = PotentialSDE.mu(y)`, so `ψ` has shape `[y.shape[0], 1]`, i.e., `[n_samples, 1]`. Whereas the regular `NeuralSDE` computes `y(t) = net(y)` directly, a `PotentialSDE` takes one further step in `PotentialSDE._gradient(ψ, y)`, which differentiates `ψ` with respect to `y`:\n", "\n", "```python\n", "y_hat = torch.autograd.grad(ψ, y, torch.ones_like(ψ), create_graph=True)[0]\n", "```\n", "\n", "The resulting gradient has the same shape as `y`. We'll skip `self._coef_drift` for now." ] }, { "cell_type": "code", "execution_count": 32, "id": "6f4818b4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.0005, 0.0063, 0.0015, -0.0210, 0.0196, 0.0255, -0.0005, 0.0258,\n", " 0.0339, -0.0169, 0.0299, -0.0004, 0.0085, 0.0188, 0.0123, 0.0392,\n", " -0.0136, 0.0060, -0.0170, -0.0014],\n", " [ 0.0078, 0.0026, -0.0105, -0.0021, 0.0021, -0.0016, -0.0103, -0.0044,\n", " -0.0142, 0.0098, 0.0386, -0.0320, 0.0127, 0.0068, -0.0091, 0.0138,\n", " -0.0273, 0.0189, 0.0141, 0.0097],\n", " [ 0.0094, -0.0139, 0.0119, -0.0075, 0.0208, -0.0016, 0.0210, 0.0203,\n", " -0.0203, 0.0288, -0.0045, 0.0091, 0.0089, 0.0134, 0.0088, -0.0148,\n", " 0.0058, 0.0015, -0.0235, -0.0288],\n", " [-0.0128, -0.0092, 0.0382, -0.0067, 0.0087, 0.0202, -0.0214, 0.0349,\n", " 0.0063, 0.0163, -0.0087, 0.0375, 0.0060, -0.0056, 0.0181, -0.0062,\n", " 0.0110, -0.0058, 0.0345, -0.0152],\n", " [-0.0299, -0.0129, 0.0264, 0.0022, 0.0008, 0.0360, 0.0215, -0.0047,\n", " -0.0065, 0.0089, -0.0315, 0.0285, -0.0151, 0.0198, 0.0384, 0.0388,\n", " -0.0133, 0.0080, 0.0215, 0.0101]], grad_fn=)\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 32;\n", " var nbb_unformatted_code = \"y_f_hat = SDE.f(t=None, 
y=y)\\nprint(y_f_hat)\";\n", " var nbb_formatted_code = \"y_f_hat = SDE.f(t=None, y=y)\\nprint(y_f_hat)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_f_hat = SDE.f(t=None, y=y)\n", "print(y_f_hat)" ] }, { "cell_type": "code", "execution_count": 26, "id": "af3e486c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.0005, 0.0063, 0.0015, -0.0210, 0.0196, 0.0255, -0.0005, 0.0258,\n", " 0.0339, -0.0169, 0.0299, -0.0004, 0.0085, 0.0188, 0.0123, 0.0392,\n", " -0.0136, 0.0060, -0.0170, -0.0014],\n", " [ 0.0078, 0.0026, -0.0105, -0.0021, 0.0021, -0.0016, -0.0103, -0.0044,\n", " -0.0142, 0.0098, 0.0386, -0.0320, 0.0127, 0.0068, -0.0091, 0.0138,\n", " -0.0273, 0.0189, 0.0141, 0.0097],\n", " [ 0.0094, -0.0139, 0.0119, -0.0075, 0.0208, -0.0016, 0.0210, 0.0203,\n", " -0.0203, 0.0288, -0.0045, 0.0091, 0.0089, 0.0134, 0.0088, -0.0148,\n", " 0.0058, 0.0015, -0.0235, -0.0288],\n", " [-0.0128, -0.0092, 0.0382, -0.0067, 0.0087, 0.0202, -0.0214, 0.0349,\n", " 0.0063, 0.0163, -0.0087, 0.0375, 0.0060, -0.0056, 0.0181, -0.0062,\n", " 0.0110, -0.0058, 0.0345, -0.0152],\n", " [-0.0299, -0.0129, 0.0264, 0.0022, 0.0008, 0.0360, 0.0215, -0.0047,\n", " -0.0065, 0.0089, -0.0315, 0.0285, -0.0151, 0.0198, 0.0384, 0.0388,\n", " -0.0133, 0.0080, 0.0215, 0.0101]], grad_fn=)\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 26;\n", " var nbb_unformatted_code = \"y_f_hat = SDE.drift(y=y)\\nprint(y_f_hat)\";\n", " var nbb_formatted_code = \"y_f_hat = SDE.drift(y=y)\\nprint(y_f_hat)\";\n", " var nbb_cells = 
Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_f_hat = SDE.drift(y=y)\n", "print(y_f_hat)" ] }, { "cell_type": "markdown", "id": "1019d7e9", "metadata": {}, "source": [ "Compute the potential, ψ, of the observed state, y — i.e., ψ(y):" ] }, { "cell_type": "code", "execution_count": 31, "id": "83493b53", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.0083],\n", " [ 0.0478],\n", " [-0.0464],\n", " [ 0.0688],\n", " [ 0.0761]], grad_fn=)\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 31;\n", " var nbb_unformatted_code = \"y_hat_potential = SDE._potential(y=y)\\nprint(y_hat_potential)\";\n", " var nbb_formatted_code = \"y_hat_potential = SDE._potential(y=y)\\nprint(y_hat_potential)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_hat_potential = SDE._potential(y=y)\n", "print(y_hat_potential)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:sdq]", "language": "python", "name": "conda-env-sdq-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 
"3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }