LatentPotentialODE.reference.ipynb

[1]:
# %load_ext nb_black
# pip install neural-diffeqs

import neural_diffeqs
import torch

print(f"Version: {neural_diffeqs.__version__}")
Version: 0.3.2
[2]:
ODE = neural_diffeqs.LatentPotentialODE(state_size=20, mu_hidden=[32, 32])
print(ODE)
LatentPotentialODE(
  (mu): TorchNet(
    (hidden_1): Sequential(
      (linear): Linear(in_features=20, out_features=32, bias=True)
      (activation): LeakyReLU(negative_slope=0.01)
    )
    (hidden_2): Sequential(
      (linear): Linear(in_features=32, out_features=32, bias=True)
      (activation): LeakyReLU(negative_slope=0.01)
    )
    (output): Sequential(
      (linear): Linear(in_features=32, out_features=20, bias=True)
    )
  )
  (potential): Potential(
    (psi): Linear(in_features=20, out_features=1, bias=False)
  )
)

Notice that the output layer of the potential function (psi) has only a single output feature and no bias (by default):

(potential): Potential(
  (psi): Linear(in_features=20, out_features=1, bias=False)
)
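To see why a single, bias-free output feature is useful, the sketch below rebuilds the layers printed for ODE with plain torch.nn modules and differentiates the scalar output with respect to the state. The modules mu and psi here are stand-ins, not the neural_diffeqs objects. Composing them as psi(mu(y)) and using the negative gradient as the drift are assumptions made for illustration; the PotentialSDE notebook describes how neural_diffeqs itself uses the potential.

```python
import torch
from torch import nn

torch.manual_seed(0)

state_size = 20

# Hypothetical stand-in for the printed `mu` TorchNet: 20 -> 32 -> 32 -> 20
mu = nn.Sequential(
    nn.Linear(state_size, 32), nn.LeakyReLU(negative_slope=0.01),
    nn.Linear(32, 32), nn.LeakyReLU(negative_slope=0.01),
    nn.Linear(32, state_size),
)
# Hypothetical stand-in for `potential.psi`: one output feature, no bias
psi = nn.Linear(state_size, 1, bias=False)


def potential_drift(y: torch.Tensor) -> torch.Tensor:
    """Assumed drift: negative gradient of the scalar potential psi(mu(y))."""
    y = y.detach().requires_grad_(True)
    potential = psi(mu(y))  # shape [batch, 1]: one scalar per sample
    # Summing over the batch is safe: each sample's potential depends
    # only on its own row of y, so the gradient stays per-sample.
    (grad,) = torch.autograd.grad(potential.sum(), y)
    return -grad  # shape [batch, state_size], same as the state


y = torch.randn(5, state_size)
drift = potential_drift(y)
print(drift.shape)  # torch.Size([5, 20])
```

Because the potential is a scalar per sample, its gradient automatically has the same shape as the state. For training, torch.autograd.grad would also need create_graph=True so that the drift itself can be backpropagated.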

So far, this is analogous to neural_diffeqs.PotentialSDE. The key difference is the introduction of the h() function, which torchsde uses as the prior drift when computing the KL divergence (logqp=True).
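torchsde looks up a fixed set of names on the object passed to sdeint: the attributes noise_type and sde_type, the drift f(t, y), the diffusion g(t, y) and, when logqp=True, the prior drift h(t, y). LatentPotentialODE exposes f and h with the same (t, y) signature, as cell [4] shows. The toy class below (ToyLatentSDE is hypothetical, not part of neural_diffeqs) only illustrates that interface, using a learned linear drift, an assumed zero prior drift, and a constant diagonal diffusion.

```python
import torch
from torch import nn


class ToyLatentSDE(nn.Module):
    """Hypothetical minimal object following torchsde's SDE interface."""

    noise_type = "diagonal"  # g(t, y) returns one noise scale per state dimension
    sde_type = "ito"

    def __init__(self, state_size: int, sigma: float = 0.1):
        super().__init__()
        self.net = nn.Linear(state_size, state_size)
        self.sigma = sigma

    def f(self, t, y):
        # Drift: a learned function of the state (t is ignored)
        return self.net(y)

    def h(self, t, y):
        # Prior drift: simply zero here, an assumption for illustration
        return torch.zeros_like(y)

    def g(self, t, y):
        # Diffusion: constant and diagonal
        return torch.full_like(y, self.sigma)


sde = ToyLatentSDE(state_size=20)
y = torch.randn(5, 20)
print(sde.f(None, y).shape, sde.h(None, y).shape, sde.g(None, y).shape)
```

An instance of such a class could be passed as torchsde.sdeint(sde, y, ts, logqp=True), the same call made with ODE in cell [6].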

LatentPotentialODE.h(t, y)

[3]:
# 5 samples x 20 dim
y = torch.randn([5, 20])
print(y)
tensor([[-0.8237,  1.0351,  1.7423, -1.5425,  0.4430,  0.8470, -1.4426,  0.4999,
          1.5510, -0.2683, -1.7453,  0.9307,  1.3925,  1.5210,  1.6499,  0.7581,
         -0.7673, -0.3574,  0.6198, -1.3364],
        [-0.8484, -0.0280, -0.9399, -1.2865,  0.0272,  1.3988, -0.8500,  0.0987,
         -0.0119,  0.1664,  1.4062, -0.9316, -1.2265, -0.4161,  0.3649,  0.1787,
          0.4278,  0.7464, -1.7572, -0.1976],
        [-0.5799,  0.2325,  0.0680,  0.8387, -0.1876, -2.5092,  0.0367, -0.0986,
         -0.7004, -1.6255, -0.3075, -0.4883, -0.2802,  0.5861, -0.1205,  0.2164,
         -0.9112,  0.1279,  1.9563,  0.5955],
        [-0.5629,  0.3056, -0.8121, -1.9160, -0.9585,  0.1542, -0.7999,  0.2500,
         -0.8330, -1.6756, -0.3278,  2.4593, -0.6612, -0.9197,  0.0237, -1.4666,
          0.9548,  0.7736,  0.8984, -1.3243],
        [-0.0242, -1.0799, -0.3576, -0.3413, -0.6483, -1.3423, -0.9658,  0.8313,
          0.3130, -0.5277, -1.0104,  0.3941,  0.2500, -0.1888,  0.3429, -0.9617,
         -1.5048, -0.4335, -0.3003,  0.8630]])
[4]:
# drift f(t, y): no prior assumption (the time argument is passed as None)
f_out = ODE.f(None, y)

# prior drift h(t, y): the reference drift used for the KL term
h_out = ODE.h(None, y)
[5]:
print(f"f_out shape: {f_out.shape}, h_out shape: {h_out.shape}")
f_out shape: torch.Size([5, 20]), h_out shape: torch.Size([5, 20])

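When called with logqp=True, torchsde.sdeint also computes the KL divergence between the path distributions defined by the drift f and by the prior drift h, and returns it alongside the predicted states, with one value per time interval and sample: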

[6]:
import torchsde

pred, kl_div = torchsde.sdeint(ODE, y, ts=torch.linspace(0, 0.1, 3), logqp=True)
print(pred.shape, kl_div.shape)
torch.Size([3, 5, 20]) torch.Size([2, 5])
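The shapes follow from how the KL term is accumulated: ts holds 3 time points, so there are 2 intervals and kl_div has shape [len(ts) - 1, batch]. For diagonal noise, torchsde integrates 0.5 * ||(f - h) / g||^2 over time alongside the state and reports the increment for each interval. The sketch below reproduces that bookkeeping with a plain fixed-step Euler-Maruyama loop; the drifts, diffusion and step size are hypothetical, and torchsde's own solvers differ in detail.

```python
import torch

torch.manual_seed(0)

batch, state_size = 5, 20
ts = torch.linspace(0, 0.1, 3)  # 3 output times -> 2 intervals
dt = 0.01  # hypothetical fixed solver step

# Hypothetical drifts and diffusion, for illustration only
W = 0.1 * torch.randn(state_size, state_size)


def f(t, y):
    return y @ W.T  # posterior drift


def h(t, y):
    return -y  # prior drift


def g(t, y):
    return torch.full_like(y, 0.5)  # diagonal diffusion


def euler_with_kl(y0, ts, dt):
    """Euler-Maruyama solve that also integrates 0.5 * ||(f - h) / g||^2 per interval."""
    y = y0
    states, kl_increments = [y0], []
    for t0, t1 in zip(ts[:-1].tolist(), ts[1:].tolist()):
        n_steps = max(1, round((t1 - t0) / dt))
        step = (t1 - t0) / n_steps
        t = t0
        kl = torch.zeros(y.shape[0])  # KL accumulated over this interval, per sample
        for _ in range(n_steps):
            fy, hy, gy = f(t, y), h(t, y), g(t, y)
            u = (fy - hy) / gy
            kl = kl + 0.5 * (u ** 2).sum(dim=1) * step
            y = y + fy * step + gy * torch.randn_like(y) * step ** 0.5
            t += step
        states.append(y)
        kl_increments.append(kl)
    return torch.stack(states), torch.stack(kl_increments)


y0 = torch.randn(batch, state_size)
pred, kl_div = euler_with_kl(y0, ts, dt)
print(pred.shape, kl_div.shape)  # torch.Size([3, 5, 20]) torch.Size([2, 5])
```

Summing over the first dimension, kl_div.sum(0), gives one KL value per sample for the whole time span; averaging that over the batch is a common way to turn it into a loss term.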

For more about the PotentialSDE and how the potential function works, see the PotentialSDE notebook.