LatentPotentialSDE.reference.ipynb
[1]:
%load_ext nb_black
# pip install neural-diffeqs
import neural_diffeqs
print(f"Version: {neural_diffeqs.__version__}")
import torch
Version: 0.3.2
[2]:
SDE = neural_diffeqs.LatentPotentialSDE(
state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]
)
print(SDE)
LatentPotentialSDE(
(mu): TorchNet(
(hidden_1): Sequential(
(linear): Linear(in_features=20, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(activation): LeakyReLU(negative_slope=0.01)
)
(hidden_2): Sequential(
(linear): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(activation): LeakyReLU(negative_slope=0.01)
)
(output): Sequential(
(linear): Linear(in_features=512, out_features=20, bias=False)
)
)
(sigma): TorchNet(
(hidden_1): Sequential(
(linear): Linear(in_features=20, out_features=32, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(activation): LeakyReLU(negative_slope=0.01)
)
(hidden_2): Sequential(
(linear): Linear(in_features=32, out_features=32, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(activation): LeakyReLU(negative_slope=0.01)
)
(output): Sequential(
(linear): Linear(in_features=32, out_features=20, bias=True)
)
)
(potential): Potential(
(psi): Linear(in_features=20, out_features=1, bias=False)
)
)
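Notice that the output layer of the mu function (drift network) has no bias (by default). In the printout above it returns all 20 state features; the projection to a single feature, also without bias, is the separate potential (psi) layer:
(output): Sequential( (linear): Linear(in_features=512, out_features=20, bias=False) )
(potential): Potential( (psi): Linear(in_features=20, out_features=1, bias=False) )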
So far, this is essentially the same setup as neural_diffeqs.PotentialSDE. The key difference is the introduction of the h() function.
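The potential idea mentioned above can be sketched in plain PyTorch. This is a minimal illustration, not the neural_diffeqs implementation: `psi` is a hypothetical stand-in for a potential layer, `potential_drift` is a made-up helper name, and the sign convention (drift as the negative gradient of the potential) is an assumption that may differ from the library's.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for a potential layer: 20-dim state -> 1 scalar.
psi = torch.nn.Linear(20, 1, bias=False)


def potential_drift(y: torch.Tensor) -> torch.Tensor:
    """Return -d(psi)/dy for each sample (assumed sign convention)."""
    y = y.detach().requires_grad_(True)
    potential = psi(y)  # shape: [batch, 1]
    # Each sample's potential depends only on its own row of y,
    # so the gradient of the batch sum gives per-sample gradients.
    (grad,) = torch.autograd.grad(potential.sum(), y, create_graph=True)
    return -grad


y_demo = torch.randn(5, 20)  # 5 samples x 20 dim
drift = potential_drift(y_demo)
print(drift.shape)
```

The result has the same shape as the input state, which is what an SDE solver expects from a drift function, even though the potential itself is a single scalar per sample.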
LatentPotentialSDE.h(y)
[3]:
# 5 samples x 20 dim
y = torch.randn([5, 20])
print(y)
tensor([[-0.1489, -0.8965, -0.4265, -0.0423, -0.2972, -0.6784, 0.2970, -0.1280,
1.0155, -0.0928, 2.4592, 0.4539, -0.6524, 0.0763, 1.9708, 0.7134,
0.2520, 0.9877, 0.0877, 1.1103],
[-0.2822, 0.6845, -0.1269, -2.4295, 1.7630, 0.4437, -1.7529, -1.2660,
0.1211, -0.1806, -0.5824, 0.6802, -0.9001, 1.1840, 0.6217, -0.6073,
-0.6521, -0.1147, -0.5269, 0.8745],
[-1.2511, 1.0869, -1.1027, -0.6977, 0.3685, 0.2092, -0.2874, -0.3324,
-1.3541, -0.9840, -0.8346, -0.4875, -1.7272, 0.7593, 1.0221, 0.6649,
-0.0999, -0.2326, -0.1195, 0.8868],
[-1.0061, -0.0645, -0.6290, 0.4299, -0.7686, 1.6165, 0.2057, -0.4136,
-1.9558, 0.4749, -1.5010, 0.5263, 0.0259, -0.3909, 0.3455, 0.9469,
-0.9320, 1.4909, -1.5104, 1.1154],
[ 0.4129, 0.0150, 0.3230, -0.0096, -0.4123, -1.0741, -0.3676, 0.2829,
-0.4771, -0.5568, 1.7971, -0.5946, 0.3360, -0.7315, 0.5277, -0.8717,
0.9425, -0.1722, -0.1044, -0.2779]])
[4]:
# f: drift output with no prior assumption imposed (None is passed for t)
f_out = SDE.f(None, y)
# h: output of the prior-regularized drift function
h_out = SDE.h(None, y)
[5]:
print(f"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}")
f_out shape: torch.Size([5, 20]), h_out shape: torch.Size([5, 20])
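torchsde.sdeint has built-in functionality to compute a KL-divergence term when called with logqp=True. The term is returned alongside the predicted states, with one value per sample for each interval between consecutive time points (hence two rows for three time points below):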
[6]:
import torchsde
pred, kl_div = torchsde.sdeint(SDE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)
print(pred.shape, kl_div.shape)
torch.Size([3, 5, 20]) torch.Size([2, 5])
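To make the two output shapes concrete, here is a rough sketch of how such a per-interval term can be accumulated. It assumes the quantity is the time integral of 0.5 * ||(f - h) / g||^2 for diagonal noise, as in the latent SDE formulation, and it takes a single Euler-Maruyama step per interval. The functions `f`, `h`, and `g` below are toy stand-ins, not the neural_diffeqs networks, and torchsde uses its own solvers and step sizes internally.

```python
import torch

torch.manual_seed(0)


# Toy stand-ins (assumptions for illustration only).
def f(t, y):  # drift
    return -y


def h(t, y):  # prior drift
    return torch.zeros_like(y)


def g(t, y):  # diagonal diffusion
    return torch.full_like(y, 0.5)


ts = torch.linspace(0, 0.1, 3)
y_t = torch.randn(5, 20)  # 5 samples x 20 dim

states, log_ratios = [y_t], []
for t0, t1 in zip(ts[:-1], ts[1:]):
    dt = t1 - t0
    # Mismatch between drift and prior drift, scaled by the diffusion.
    u = (f(t0, y_t) - h(t0, y_t)) / g(t0, y_t)
    log_ratios.append(0.5 * (u ** 2).sum(dim=1) * dt)  # one value per sample
    # One Euler-Maruyama step across the interval.
    dW = torch.randn_like(y_t) * dt.sqrt()
    y_t = y_t + f(t0, y_t) * dt + g(t0, y_t) * dW
    states.append(y_t)

pred_demo = torch.stack(states)        # [time points, samples, dim]
kl_demo = torch.stack(log_ratios)      # [intervals, samples]
print(pred_demo.shape, kl_demo.shape)
```

Three time points give three states but only two intervals, which matches the shapes printed by the torchsde call above.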
For more about the PotentialSDE and how the potential function works, see the PotentialSDE notebook.