{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "530feebe",
   "metadata": {},
   "source": [
    "# `NeuralSDE` reference\n",
    "\n",
    "How to configure the two networks of `neural_diffeqs.NeuralSDE`, `mu` and `sigma`: hidden-layer sizes, activation functions, dropout, and bias terms. Each example prints the resulting architecture.\n",
    "\n",
    "Install with `pip install neural-diffeqs`. The outputs below were recorded with version 0.3.2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd0a3fa6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Version: 0.3.2\n"
     ]
    }
   ],
   "source": [
    "import neural_diffeqs\n",
    "import torch\n",
    "\n",
    "print(f\"Version: {neural_diffeqs.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "763caa74",
   "metadata": {},
   "source": [
    "### Default `NeuralSDE`\n",
    "\n",
    "The only required parameter is:\n",
    "\n",
    "* `state_size`\n",
    "\n",
    "Everything else falls back to the defaults visible in the printed architecture: in version 0.3.2, two hidden layers of 2000 units for `mu` and of 400 units for `sigma`, each with dropout `p=0.2` and a `LeakyReLU` activation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b650fd1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralSDE(\n",
      "  (mu): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=2000, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=2000, out_features=2000, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=2000, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (sigma): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=400, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=400, out_features=400, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=400, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "SDE = neural_diffeqs.NeuralSDE(state_size=50)\n",
    "print(SDE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49615c08",
   "metadata": {},
   "source": [
    "### Changing some parameters\n",
    "\n",
    "For example, specify the hidden-layer sizes of each network with `mu_hidden` and `sigma_hidden`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed36e1dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralSDE(\n",
      "  (mu): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=512, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=512, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (sigma): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): LeakyReLU(negative_slope=0.01)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "SDE = neural_diffeqs.NeuralSDE(\n",
    "    state_size=50, mu_hidden=[512, 512], sigma_hidden=[32, 32]\n",
    ")\n",
    "print(SDE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66dcbecc",
   "metadata": {},
   "source": [
    "### Activation functions, dropout, and bias\n",
    "\n",
    "Activations can be given per hidden layer, either as `torch.nn` classes (`mu_activation` below) or as their names (`sigma_activation` below). Setting a dropout value to `0` leaves the dropout layer out of the network altogether.\n",
    "\n",
    "**Note:** the `Softmax` layers in the outputs below are built as `Softmax(dim=None)`. PyTorch deprecates softmax without an explicit `dim` and warns when such a layer is evaluated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1c822988",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralSDE(\n",
      "  (mu): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=512, bias=True)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=512, bias=True)\n",
      "      (activation): Tanh()\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (sigma): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=32, bias=True)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=32, bias=True)\n",
      "      (activation): Tanh()\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "SDE = neural_diffeqs.NeuralSDE(\n",
    "    state_size=50,\n",
    "    mu_hidden=[512, 512],\n",
    "    sigma_hidden=[32, 32],\n",
    "    mu_activation=[torch.nn.Softmax, torch.nn.Tanh],\n",
    "    sigma_activation=[\"Softmax\", \"Tanh\"],\n",
    "    mu_dropout=0,\n",
    "    sigma_dropout=0,\n",
    ")\n",
    "print(SDE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3d91f20",
   "metadata": {},
   "source": [
    "An activation passed on its own instead of in a list (`mu_activation` below) is applied to every hidden layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "82b5e08a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralSDE(\n",
      "  (mu): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=512, bias=True)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=512, bias=True)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (sigma): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=32, bias=True)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=32, bias=True)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "SDE = neural_diffeqs.NeuralSDE(\n",
    "    state_size=50,\n",
    "    mu_hidden=[512, 512],\n",
    "    sigma_hidden=[32, 32],\n",
    "    mu_activation=torch.nn.Softmax,\n",
    "    sigma_activation=[\"Softmax\", \"Softmax\"],\n",
    "    mu_dropout=0,\n",
    "    sigma_dropout=0,\n",
    ")\n",
    "print(SDE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e4c612",
   "metadata": {},
   "source": [
    "Dropout follows the same pattern: a list sets the probability of each hidden layer (`mu_dropout` below), while a single value is used for all of them (`sigma_dropout` below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "43ecbe1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralSDE(\n",
      "  (mu): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=512, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=512, bias=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (sigma): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "SDE = neural_diffeqs.NeuralSDE(\n",
    "    state_size=50,\n",
    "    mu_hidden=[512, 512],\n",
    "    sigma_hidden=[32, 32],\n",
    "    mu_activation=torch.nn.Softmax,\n",
    "    sigma_activation=[\"Softmax\", \"Softmax\"],\n",
    "    mu_dropout=[0.2, 0.1],\n",
    "    sigma_dropout=0.2,\n",
    ")\n",
    "print(SDE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c58a07d3",
   "metadata": {},
   "source": [
    "`mu_bias` switches the bias of each hidden layer of `mu` on or off. The output layer keeps its bias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "87be648d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralSDE(\n",
      "  (mu): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=512, bias=False)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=512, bias=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (sigma): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=50, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "SDE = neural_diffeqs.NeuralSDE(\n",
    "    state_size=50,\n",
    "    mu_hidden=[512, 512],\n",
    "    sigma_hidden=[32, 32],\n",
    "    mu_activation=torch.nn.Softmax,\n",
    "    sigma_activation=[\"Softmax\", \"Softmax\"],\n",
    "    mu_dropout=[0.2, 0.1],\n",
    "    sigma_dropout=0.2,\n",
    "    mu_bias=[False, True],\n",
    ")\n",
    "print(SDE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d26f9b84",
   "metadata": {},
   "source": [
    "The bias of the output layers is controlled separately, with `mu_output_bias` and `sigma_output_bias`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c1b8ffc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralSDE(\n",
      "  (mu): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=512, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=512, bias=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=512, out_features=50, bias=False)\n",
      "    )\n",
      "  )\n",
      "  (sigma): TorchNet(\n",
      "    (hidden_1): Sequential(\n",
      "      (linear): Linear(in_features=50, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (hidden_2): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=32, bias=True)\n",
      "      (dropout): Dropout(p=0.2, inplace=False)\n",
      "      (activation): Softmax(dim=None)\n",
      "    )\n",
      "    (output): Sequential(\n",
      "      (linear): Linear(in_features=32, out_features=50, bias=False)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "SDE = neural_diffeqs.NeuralSDE(\n",
    "    state_size=50,\n",
    "    mu_hidden=[512, 512],\n",
    "    sigma_hidden=[32, 32],\n",
    "    mu_activation=torch.nn.Softmax,\n",
    "    sigma_activation=[\"Softmax\", \"Softmax\"],\n",
    "    mu_dropout=[0.2, 0.1],\n",
    "    sigma_dropout=0.2,\n",
    "    mu_output_bias=False,\n",
    "    sigma_output_bias=False,\n",
    ")\n",
    "print(SDE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:sdq]",
   "language": "python",
   "name": "conda-env-sdq-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}