#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""sde.py: contains the different Stochastic Differential Equations (SDEs) classes: VPSDE, VESDE, subVPSDE.
The classes inherit from the SDE class.
Adapted from Jo, J. & al (2022)
"""
import abc
from typing import Callable, Optional, Sequence, Tuple, Union
import numpy as np
import torch
class SDE(abc.ABC):
    """SDE abstract class. All functions are designed for a mini-batch of inputs."""

    def __init__(self, N: int) -> None:
        """Initialize a SDE.

        Args:
            N (int): number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self) -> int:
        """Return the final time of the SDE.

        Returns:
            int: final time of the SDE.
        """
        pass

    @abc.abstractmethod
    def sde(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parameters to determine the drift and diffusion functions of the SDE, $f_t(x)$ and $G_t(x)$.

        Args:
            x (torch.Tensor): feature vector.
            t (torch.Tensor): time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: drift and diffusion.
        """
        pass

    @abc.abstractmethod
    def marginal_prob(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.

        Args:
            x (torch.Tensor): feature vector.
            t (torch.Tensor): time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: mean and standard deviation of the perturbation kernel.
        """
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape: Sequence[int]) -> torch.Tensor:
        """Generate one sample from the prior distribution, $p_T(x)$.

        Args:
            shape (Sequence[int]): shape of the sample.

        Returns:
            torch.Tensor: sample from the prior distribution.
        """
        pass

    @abc.abstractmethod
    def prior_logp(self, z: torch.Tensor) -> torch.Tensor:
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z (torch.Tensor): latent sample.

        Returns:
            torch.Tensor: log probability density.
        """
        pass

    def discretize(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization with step dt = 1 / N.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: drift and diffusion (f, G).
        """
        dt = 1 / self.N
        drift, diffusion = self.sde(x, t)
        f = drift * dt
        # Euler-Maruyama: the noise term scales with sqrt(dt).
        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
        return f, G

    def reverse(
        self,
        score_fn: Union[
            Callable[
                [torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor],
                torch.Tensor,
            ],
            Callable[
                [
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    Optional[torch.Tensor],
                    torch.Tensor,
                ],
                torch.Tensor,
            ],
        ],
        probability_flow: bool = False,
        is_cc: bool = False,
    ) -> "SDE":
        """Create the reverse-time SDE/ODE (RSDE).

        Args:
            score_fn: time-dependent score-based model. When `is_cc` is False it is
                called as `score_fn(feature, x, flags, t)`; when `is_cc` is True it
                is called as `score_fn(feature, x, rank2, flags, t)`.
            probability_flow (bool, optional): If `True`, create the reverse-time ODE
                used for probability flow sampling. Defaults to False.
            is_cc (bool, optional): If `True`, the reverse-time SDE/ODE also takes the
                rank2 incidence matrix as an input. Defaults to False.

        Returns:
            SDE: reverse-time SDE/ODE. It is an instance of a subclass of
            `self.__class__` named `RSDE`.
        """
        N = self.N
        T = self.T
        # NOTE: this also records the mode on the forward SDE itself (kept for
        # backward compatibility with callers that may read `sde.is_cc`).
        self.is_cc = is_cc
        # Capture the forward coefficients as bound methods so the reverse object
        # does not need the forward constructor's state.
        sde_fn = self.sde
        discretize_fn = self.discretize

        class _RSDEBase(self.__class__):
            """State and helpers shared by both reverse-time variants."""

            def __init__(self) -> None:
                """Initialize the reverse-time SDE/ODE."""
                # The parent __init__ is intentionally not called; the forward
                # coefficients are reached through the captured closures.
                self.N = N
                self.probability_flow = probability_flow
                self.is_cc = is_cc

            def __repr__(self) -> str:
                """Return the string representation of the reverse-time SDE/ODE.

                Returns:
                    str: string representation of the reverse-time SDE/ODE.
                """
                return f"{self.__class__.__name__}(N={self.N}, probability_flow={self.probability_flow}, T={self.T}, is_cc={self.is_cc})"

            @property
            def T(self) -> int:
                """Return the final time of the reverse-time SDE/ODE.

                Returns:
                    int: final time of the reverse-time SDE/ODE.
                """
                return T

            def _reverse_drift(
                self,
                drift: torch.Tensor,
                diffusion: torch.Tensor,
                score: torch.Tensor,
            ) -> torch.Tensor:
                """Apply the score correction drift - g^2 * score (halved for the ODE)."""
                return drift - diffusion[:, None, None] ** 2 * score * (
                    0.5 if self.probability_flow else 1.0
                )

        if not is_cc:

            class RSDE(_RSDEBase):
                """Reverse-time SDE/ODE."""

                def sde(
                    self,
                    feature: torch.Tensor,
                    x: torch.Tensor,
                    flags: torch.Tensor,
                    t: torch.Tensor,
                    is_adj: bool = True,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
                    """Returns the drift and diffusion for the reverse SDE/ODE.

                    Args:
                        feature (torch.Tensor): torch tensor.
                        x (torch.Tensor): torch tensor.
                        flags (torch.Tensor): flags.
                        t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).
                        is_adj (bool, optional): True if reverse-SDE for the adjacency matrix. Defaults to True.

                    Returns:
                        Tuple[torch.Tensor, torch.Tensor]: drift and diffusion.
                    """
                    drift, diffusion = sde_fn(x, t) if is_adj else sde_fn(feature, t)
                    score = score_fn(feature, x, flags, t)
                    drift = self._reverse_drift(drift, diffusion, score)
                    # Set the diffusion function to zero for ODEs.
                    diffusion = 0.0 if self.probability_flow else diffusion
                    return drift, diffusion

                def discretize(
                    self,
                    feature: torch.Tensor,
                    x: torch.Tensor,
                    flags: torch.Tensor,
                    t: torch.Tensor,
                    is_adj: bool = True,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
                    """Create discretized iteration rules for the reverse diffusion sampler.

                    Args:
                        feature (torch.Tensor): torch tensor.
                        x (torch.Tensor): torch tensor.
                        flags (torch.Tensor): flags.
                        t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).
                        is_adj (bool, optional): True if reverse-SDE for the adjacency matrix. Defaults to True.

                    Returns:
                        Tuple[torch.Tensor, torch.Tensor]: discretized drift and diffusion (f, G).
                    """
                    f, G = discretize_fn(x, t) if is_adj else discretize_fn(feature, t)
                    score = score_fn(feature, x, flags, t)
                    rev_f = self._reverse_drift(f, G, score)
                    rev_G = torch.zeros_like(G) if self.probability_flow else G
                    return rev_f, rev_G

        else:

            class RSDE(_RSDEBase):
                """Reverse-time SDE/ODE taking the rank2 incidence matrix as input."""

                def sde(
                    self,
                    feature: torch.Tensor,
                    x: torch.Tensor,
                    rank2: torch.Tensor,
                    flags: torch.Tensor,
                    t: torch.Tensor,
                    is_adj: bool = True,
                    is_rank2: bool = False,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
                    """Returns the drift and diffusion for the reverse SDE/ODE.

                    Args:
                        feature (torch.Tensor): torch tensor.
                        x (torch.Tensor): torch tensor.
                        rank2 (torch.Tensor): torch tensor.
                        flags (torch.Tensor): flags.
                        t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).
                        is_adj (bool, optional): True if reverse-SDE for the adjacency matrix. Defaults to True.
                        is_rank2 (bool, optional): True if reverse-SDE for the rank2 incidence matrix.
                            `is_adj` must also be set to False. Defaults to False.

                    Returns:
                        Tuple[torch.Tensor, torch.Tensor]: drift and diffusion.
                    """
                    # `is_adj` takes precedence over `is_rank2`.
                    if is_adj:
                        drift, diffusion = sde_fn(x, t)
                    elif is_rank2:
                        drift, diffusion = sde_fn(rank2, t)
                    else:
                        drift, diffusion = sde_fn(feature, t)
                    score = score_fn(feature, x, rank2, flags, t)
                    drift = self._reverse_drift(drift, diffusion, score)
                    # Set the diffusion function to zero for ODEs.
                    diffusion = 0.0 if self.probability_flow else diffusion
                    return drift, diffusion

                def discretize(
                    self,
                    feature: torch.Tensor,
                    x: torch.Tensor,
                    rank2: torch.Tensor,
                    flags: torch.Tensor,
                    t: torch.Tensor,
                    is_adj: bool = True,
                    is_rank2: bool = False,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
                    """Create discretized iteration rules for the reverse diffusion sampler.

                    Args:
                        feature (torch.Tensor): torch tensor.
                        x (torch.Tensor): torch tensor.
                        rank2 (torch.Tensor): torch tensor.
                        flags (torch.Tensor): flags.
                        t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).
                        is_adj (bool, optional): True if reverse-SDE for the adjacency matrix. Defaults to True.
                        is_rank2 (bool, optional): True if reverse-SDE for the rank2 incidence matrix.
                            `is_adj` must also be set to False. Defaults to False.

                    Returns:
                        Tuple[torch.Tensor, torch.Tensor]: discretized drift and diffusion (f, G).
                    """
                    # `is_adj` takes precedence over `is_rank2`.
                    if is_adj:
                        f, G = discretize_fn(x, t)
                    elif is_rank2:
                        f, G = discretize_fn(rank2, t)
                    else:
                        f, G = discretize_fn(feature, t)
                    score = score_fn(feature, x, rank2, flags, t)
                    rev_f = self._reverse_drift(f, G, score)
                    rev_G = torch.zeros_like(G) if self.probability_flow else G
                    return rev_f, rev_G

        return RSDE()
class VPSDE(SDE):
    """Variance Preserving SDE (VPSDE)."""

    def __init__(
        self, beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000
    ) -> None:
        """Construct a Variance Preserving SDE.

        Args:
            beta_min (float): value of beta(0).
            beta_max (float): value of beta(1).
            N (int): number of discretization steps.
        """
        super().__init__(N)  # sets self.N
        # Initialize the parameters
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        # Discrete (DDPM) noise schedule, used by `discretize`.
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
        self.alphas = 1.0 - self.discrete_betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def __repr__(self) -> str:
        """Return the string representation of the SDE.

        Returns:
            str: string representation of the SDE.
        """
        return f"{self.__class__.__name__}(N={self.N}, beta_min={self.beta_0}, beta_max={self.beta_1}, T={self.T})"

    @property
    def T(self) -> int:
        """Return the final time of the SDE.

        Returns:
            int: final time of the SDE.
        """
        return 1

    def sde(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the drift and diffusion for the SDE.

        Args:
            x (torch.Tensor): torch tensor; the broadcasting below assumes a
                rank-3 (batch, ·, ·) tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: drift and diffusion.
        """
        # beta(t) = beta_min + t * (beta_max - beta_min)
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * beta_t[:, None, None] * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def marginal_prob(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the mean and std of the perturbation kernel.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: mean and std of the perturbation kernel.
        """
        # log of exp(-0.5 * integral_0^t beta(s) ds)
        log_mean_coeff = (
            -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        )
        mean = torch.exp(log_mean_coeff[:, None, None]) * x
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
        return mean, std

    def prior_sampling(self, shape: Sequence[int]) -> torch.Tensor:
        """Sample from the prior distribution.

        Here the prior is a standard Gaussian distribution.

        Args:
            shape (Sequence[int]): shape of the output tensor.

        Returns:
            torch.Tensor: sample from the prior distribution.
        """
        return torch.randn(*shape)

    def prior_sampling_sym(self, shape: Sequence[int]) -> torch.Tensor:
        """Sample from the prior distribution in the symmetric case for a matrix.

        The strictly upper triangle is drawn from a standard Gaussian and mirrored;
        the diagonal is zero.

        Args:
            shape (Sequence[int]): shape of the output tensor.

        Returns:
            torch.Tensor: sample from the prior distribution.
        """
        x = torch.randn(*shape).triu(1)
        return x + x.transpose(-1, -2)

    def prior_logp(self, z: torch.Tensor) -> torch.Tensor:
        """Returns the log probability of the standard Gaussian prior.

        Args:
            z (torch.Tensor): latent sample, batch dimension first.

        Returns:
            torch.Tensor: log probability of the prior distribution, one value per sample.
        """
        shape = z.shape
        N = np.prod(shape[1:])
        # Reduce over every non-batch dimension so the quadratic term matches N.
        reduce_dims = tuple(range(1, z.dim()))
        logps = -N / 2.0 * np.log(2 * np.pi) - torch.sum(z**2, dim=reduce_dims) / 2.0
        return logps

    def discretize(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """DDPM discretization for the drift and diffusion of the SDE.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: discretized drift and diffusion (f, G).
        """
        # Map continuous t in [0, T] to an index into the discrete schedule.
        timestep = (t * (self.N - 1) / self.T).long().to("cpu")
        beta = self.discrete_betas.to(x.device)[timestep]
        alpha = self.alphas.to(x.device)[timestep]
        sqrt_beta = torch.sqrt(beta)
        f = torch.sqrt(alpha)[:, None, None] * x - x
        G = sqrt_beta
        return f, G

    def transition(
        self, x: torch.Tensor, t: torch.Tensor, dt: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the mean and std of the transition kernel.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).
            dt (float): time step (here negative timestep dt).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: mean and std of the transition kernel.
        """
        # 0.5 * integral_t^{t+dt} beta(s) ds
        log_mean_coeff = (
            0.25 * dt * (2 * self.beta_0 + (2 * t + dt) * (self.beta_1 - self.beta_0))
        )
        mean = torch.exp(-log_mean_coeff[:, None, None]) * x
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
        return mean, std
class VESDE(SDE):
    """Variance Exploding SDE (VESDE)."""

    def __init__(
        self, sigma_min: float = 0.01, sigma_max: float = 50.0, N: int = 1000
    ) -> None:
        """Initialize the Variance Exploding SDE.

        Args:
            sigma_min (float): smallest sigma.
            sigma_max (float): largest sigma.
            N (int): number of discretization steps.
        """
        super().__init__(N)  # sets self.N
        # Initialize the parameters
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # N noise levels spaced geometrically from sigma_min to sigma_max.
        self.discrete_sigmas = torch.exp(
            torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)
        )

    def __repr__(self) -> str:
        """Return the string representation of the SDE.

        Returns:
            str: string representation of the SDE.
        """
        return f"{self.__class__.__name__}(N={self.N}, sigma_min={self.sigma_min}, sigma_max={self.sigma_max}, T={self.T})"

    @property
    def T(self) -> int:
        """Return the final time of the SDE.

        Returns:
            int: final time of the SDE.
        """
        return 1

    def sde(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the drift and diffusion of the SDE.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: drift and diffusion of the SDE.
        """
        # sigma(t) = sigma_min * (sigma_max / sigma_min) ** t
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        drift = torch.zeros_like(x)
        diffusion = sigma * torch.sqrt(
            torch.tensor(
                2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device
            )
        )
        return drift, diffusion

    def marginal_prob(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the mean and std of the marginal distribution at time t.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: mean and std of the marginal distribution.
        """
        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        mean = x
        return mean, std

    def prior_sampling(self, shape: Sequence[int]) -> torch.Tensor:
        """Returns a sample from the prior distribution.

        Here the prior is a standard Gaussian distribution.

        Args:
            shape (Sequence[int]): shape of the sample.

        Returns:
            torch.Tensor: sample from the prior distribution.
        """
        return torch.randn(*shape)

    def prior_sampling_sym(self, shape: Sequence[int]) -> torch.Tensor:
        """Returns a symmetric sample from the prior distribution.

        The strictly upper triangle is drawn from a standard Gaussian and mirrored;
        the diagonal is zero.

        Args:
            shape (Sequence[int]): shape of the sample.

        Returns:
            torch.Tensor: sample from the prior distribution.
        """
        x = torch.randn(*shape).triu(1)
        x = x + x.transpose(-1, -2)
        return x

    def prior_logp(self, z: torch.Tensor) -> torch.Tensor:
        """Returns the log probability of the N(0, sigma_max^2) prior.

        Args:
            z (torch.Tensor): latent sample, batch dimension first.

        Returns:
            torch.Tensor: log probability of the prior distribution, one value per sample.
        """
        shape = z.shape
        N = np.prod(shape[1:])
        # Reduce over every non-batch dimension (the previous hard-coded
        # dim=(1, 2, 3) failed on the rank-3 tensors used by this module).
        reduce_dims = tuple(range(1, z.dim()))
        return -N / 2.0 * np.log(2 * np.pi * self.sigma_max**2) - torch.sum(
            z**2, dim=reduce_dims
        ) / (2 * self.sigma_max**2)

    def discretize(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the drift and diffusion of the discretized SDE.

        SMLD (NCSN) discretization.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: drift and diffusion of the discretized SDE.
        """
        timestep = (t * (self.N - 1) / self.T).long().to("cpu")
        sigma = self.discrete_sigmas.to(t.device)[timestep]
        # At timestep 0 the previous sigma is taken as 0; the wrapped index -1
        # evaluated there is discarded by torch.where.
        adjacent_sigma = torch.where(
            timestep.to(t.device) == 0,
            torch.zeros_like(t),
            self.discrete_sigmas[timestep - 1].to(t.device),
        )
        f = torch.zeros_like(x)
        G = torch.sqrt(sigma**2 - adjacent_sigma**2)
        return f, G

    def transition(
        self, x: torch.Tensor, t: torch.Tensor, dt: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the mean and std of the transition kernel at time t and timestep dt.

        (A negative timestep dt means going backward in time.)

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).
            dt (float): timestep.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: mean and std of the transition kernel.
        """
        std = torch.square(
            self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        ) - torch.square(self.sigma_min * (self.sigma_max / self.sigma_min) ** (t + dt))
        std = torch.sqrt(std)
        mean = x
        return mean, std
class subVPSDE(SDE):
    """Class for the sub-VP SDE that excels at likelihoods."""

    def __init__(
        self, beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000
    ) -> None:
        """Construct the sub-VP SDE that excels at likelihoods.

        Args:
            beta_min (float): value of beta(0).
            beta_max (float): value of beta(1).
            N (int): number of discretization steps.
        """
        super().__init__(N)  # sets self.N
        # Initialize the parameters
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
        self.alphas = 1.0 - self.discrete_betas

    def __repr__(self) -> str:
        """Return the string representation of the SDE.

        Returns:
            str: string representation of the SDE.
        """
        return f"{self.__class__.__name__}(N={self.N}, beta_min={self.beta_0}, beta_max={self.beta_1}, T={self.T})"

    @property
    def T(self) -> int:
        """Returns the final time of the SDE.

        Returns:
            int: final time of the SDE.
        """
        return 1

    def sde(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the drift and diffusion of the SDE at time t.

        Args:
            x (torch.Tensor): torch tensor; the broadcasting below assumes a
                rank-3 (batch, ·, ·) tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: drift and diffusion of the SDE.
        """
        # beta(t) = beta_min + t * (beta_max - beta_min)
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * beta_t[:, None, None] * x
        # The sub-VP diffusion is the VP diffusion scaled by 1 - exp(-2 * integral beta).
        discount = 1.0 - torch.exp(
            -2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t**2
        )
        diffusion = torch.sqrt(beta_t * discount)
        return drift, diffusion

    def marginal_prob(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the mean and std of the marginal distribution at time t.

        Args:
            x (torch.Tensor): torch tensor.
            t (torch.Tensor): torch float representing the time step (from 0 to `self.T`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: mean and std of the marginal distribution.
        """
        log_mean_coeff = (
            -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        )
        mean = torch.exp(log_mean_coeff)[:, None, None] * x
        # For the sub-VP SDE the std itself (not the variance) is 1 - exp(2 * log_mean_coeff).
        std = 1 - torch.exp(2.0 * log_mean_coeff)
        return mean, std

    def prior_sampling(self, shape: Sequence[int]) -> torch.Tensor:
        """Returns a sample from the prior distribution.

        Here, the prior distribution is a standard Gaussian.

        Args:
            shape (Sequence[int]): shape of the sample.

        Returns:
            torch.Tensor: sample from the prior distribution.
        """
        return torch.randn(*shape)

    def prior_sampling_sym(self, shape: Sequence[int]) -> torch.Tensor:
        """Returns a symmetric sample from the prior distribution.

        The strictly upper triangle is drawn from a standard Gaussian and mirrored;
        the diagonal is zero.

        Args:
            shape (Sequence[int]): shape of the sample.

        Returns:
            torch.Tensor: sample from the prior distribution.
        """
        x = torch.randn(*shape).triu(1)
        return x + x.transpose(-1, -2)

    def prior_logp(self, z: torch.Tensor) -> torch.Tensor:
        """Returns the log probability of the standard Gaussian prior.

        Args:
            z (torch.Tensor): latent sample, batch dimension first.

        Returns:
            torch.Tensor: log probability of the prior distribution, one value per sample.
        """
        shape = z.shape
        N = np.prod(shape[1:])
        # Reduce over every non-batch dimension (the previous hard-coded
        # dim=(1, 2, 3) failed on the rank-3 tensors used by this module).
        reduce_dims = tuple(range(1, z.dim()))
        return -N / 2.0 * np.log(2 * np.pi) - torch.sum(z**2, dim=reduce_dims) / 2.0