LatentPotentialSDE.reference.ipynb

[1]:
%load_ext nb_black
# pip install neural-diffeqs

import neural_diffeqs
import torch

print(f"Version: {neural_diffeqs.__version__}")
Version: 0.3.2
[2]:
SDE = neural_diffeqs.LatentPotentialSDE(
    state_size=20, mu_hidden=[512, 512], sigma_hidden=[32, 32]
)
print(SDE)
LatentPotentialSDE(
  (mu): TorchNet(
    (hidden_1): Sequential(
      (linear): Linear(in_features=20, out_features=512, bias=True)
      (dropout): Dropout(p=0.2, inplace=False)
      (activation): LeakyReLU(negative_slope=0.01)
    )
    (hidden_2): Sequential(
      (linear): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.2, inplace=False)
      (activation): LeakyReLU(negative_slope=0.01)
    )
    (output): Sequential(
      (linear): Linear(in_features=512, out_features=20, bias=False)
    )
  )
  (sigma): TorchNet(
    (hidden_1): Sequential(
      (linear): Linear(in_features=20, out_features=32, bias=True)
      (dropout): Dropout(p=0.2, inplace=False)
      (activation): LeakyReLU(negative_slope=0.01)
    )
    (hidden_2): Sequential(
      (linear): Linear(in_features=32, out_features=32, bias=True)
      (dropout): Dropout(p=0.2, inplace=False)
      (activation): LeakyReLU(negative_slope=0.01)
    )
    (output): Sequential(
      (linear): Linear(in_features=32, out_features=20, bias=True)
    )
  )
  (potential): Potential(
    (psi): Linear(in_features=20, out_features=1, bias=False)
  )
)

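Notice that the output layer of the mu function (drift network) keeps the full state dimension (20) and has no bias (by default), while the psi layer of the potential reduces the state to a single feature, also without bias:

(output): Sequential(
 (linear): Linear(in_features=512, out_features=20, bias=False)
)
(potential): Potential(
 (psi): Linear(in_features=20, out_features=1, bias=False)
)

So far, this follows the same pattern as the neural_diffeqs.PotentialSDE. The key difference is the introduction of the h() function.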

LatentPotentialSDE.h(y)

[3]:
# 5 samples x 20 dim
y = torch.randn([5, 20])
print(y)
tensor([[-0.1489, -0.8965, -0.4265, -0.0423, -0.2972, -0.6784,  0.2970, -0.1280,
          1.0155, -0.0928,  2.4592,  0.4539, -0.6524,  0.0763,  1.9708,  0.7134,
          0.2520,  0.9877,  0.0877,  1.1103],
        [-0.2822,  0.6845, -0.1269, -2.4295,  1.7630,  0.4437, -1.7529, -1.2660,
          0.1211, -0.1806, -0.5824,  0.6802, -0.9001,  1.1840,  0.6217, -0.6073,
         -0.6521, -0.1147, -0.5269,  0.8745],
        [-1.2511,  1.0869, -1.1027, -0.6977,  0.3685,  0.2092, -0.2874, -0.3324,
         -1.3541, -0.9840, -0.8346, -0.4875, -1.7272,  0.7593,  1.0221,  0.6649,
         -0.0999, -0.2326, -0.1195,  0.8868],
        [-1.0061, -0.0645, -0.6290,  0.4299, -0.7686,  1.6165,  0.2057, -0.4136,
         -1.9558,  0.4749, -1.5010,  0.5263,  0.0259, -0.3909,  0.3455,  0.9469,
         -0.9320,  1.4909, -1.5104,  1.1154],
        [ 0.4129,  0.0150,  0.3230, -0.0096, -0.4123, -1.0741, -0.3676,  0.2829,
         -0.4771, -0.5568,  1.7971, -0.5946,  0.3360, -0.7315,  0.5277, -0.8717,
          0.9425, -0.1722, -0.1044, -0.2779]])
[4]:
# output of the drift function f, which makes no prior assumption
f_out = SDE.f(None, y)

# output of the prior-regularized function h
h_out = SDE.h(None, y)
[5]:
print(f"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}")
f_out shape: torch.Size([5, 20]), h_out shape: torch.Size([5, 20])
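
To make the idea of a potential-based drift concrete, here is a minimal sketch of how a drift can be derived from a scalar potential with autograd. It is not the neural_diffeqs implementation: the helper potential_drift, the stand-in layer psi_demo, and the sign convention (drift = negative gradient of the potential) are all assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for a potential network: 20-dim state -> scalar.
psi_demo = torch.nn.Linear(20, 1, bias=False)


def potential_drift(psi, y):
    """Return -d psi / d y, one row per sample (assumed sign convention)."""
    y = y.detach().requires_grad_(True)
    potential = psi(y)  # shape: [batch, 1]
    # Summing over the batch gives per-sample gradients, because each
    # sample's potential depends only on that sample's state.
    (grad,) = torch.autograd.grad(potential.sum(), y, create_graph=True)
    return -grad


y_demo = torch.randn(5, 20)  # 5 samples x 20 dim
drift_demo = potential_drift(psi_demo, y_demo)
print(drift_demo.shape)
```

The result has the same shape as the state, which matches the shapes of f_out and h_out above. With a purely linear potential, as in this sketch, the drift is the same for every sample; a deeper network would make it state-dependent.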

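torchsde.sdeint has built-in functionality to compute the KL divergence along the predicted trajectory. With logqp=True, it is returned alongside the predicted states, with one value per sample for each interval between consecutive time points (three time points give two intervals):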

[6]:
import torchsde

pred, kl_div = torchsde.sdeint(SDE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)
print(pred.shape, kl_div.shape)
torch.Size([3, 5, 20]) torch.Size([2, 5])
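
As a rough illustration of where such a KL term comes from, the sketch below computes the usual Girsanov-style integrand 0.5 * ||(f - h) / g||^2 for diagonal noise and scales it by a step size. That logqp=True accumulates exactly this quantity is an assumption here, and kl_rate and the demo tensors are hypothetical names, not part of torchsde or neural_diffeqs.

```python
import torch


def kl_rate(f_out, h_out, g_out):
    """Per-sample KL integrand, assuming diagonal noise.

    f_out: posterior drift, h_out: prior drift, g_out: diffusion,
    all of shape [batch, state_size].
    """
    u = (f_out - h_out) / g_out
    return 0.5 * (u**2).sum(dim=1)


# Toy values chosen so the result is easy to check by hand.
f_demo = torch.ones(5, 20)
h_demo = torch.zeros(5, 20)
g_demo = torch.full((5, 20), 2.0)

rate = kl_rate(f_demo, h_demo, g_demo)  # shape: [5]
dt = 0.05  # spacing of torch.linspace(0, 0.1, 3)
kl_increment = rate * dt  # approximate KL added over one interval
print(rate.shape, kl_increment[0].item())
```

Each interval contributes one value per sample, which is consistent with the [2, 5] shape of kl_div for three time points and five samples.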

For more about the PotentialSDE and how the potential function works, see the PotentialSDE notebook.