Diffusion
This module defines the trainers, samplers, stochastic differential equations (SDEs), solvers, and losses.
trainer.py: code for training the model.
- class ccsd.src.trainer.Trainer(config: EasyDict | None)
Bases: ABC
Abstract class for a Trainer.
- __init__(config: EasyDict | None) → None
Initialize the trainer.
- Parameters:
config (Optional[EasyDict], optional) – the config object to use. Defaults to None.
- abstract train(ts: str) → str
Load the models, optimizers, etc., train the model, and save the checkpoint.
- Parameters:
ts (str) – checkpoint name (usually a timestamp)
- Returns:
checkpoint name
- Return type:
str
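The abstract-base-class contract above can be sketched with Python's `abc` module. This is a minimal illustration, not the library's code: `ToyTrainer` is a hypothetical subclass, and a plain `dict` stands in for the `EasyDict` config. The real subclasses (Trainer_Graph, Trainer_CC) load models and optimizers and save a checkpoint before returning its name.

```python
from abc import ABC, abstractmethod
from typing import Optional


class Trainer(ABC):
    """Stand-in for the abstract Trainer contract (illustrative sketch)."""

    def __init__(self, config: Optional[dict] = None) -> None:
        # The documented class takes an EasyDict; a plain dict is used here.
        self.config = config

    @abstractmethod
    def train(self, ts: str) -> str:
        """Train the model and return the checkpoint name."""


class ToyTrainer(Trainer):
    """Hypothetical subclass: performs no training and reuses ts as the checkpoint name."""

    def train(self, ts: str) -> str:
        # A real subclass would build models/optimizers, run the training loop,
        # and save a checkpoint named after ts before returning that name.
        return ts
```

Because `train` is abstract, `Trainer` itself cannot be instantiated; only concrete subclasses can.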
- class ccsd.src.trainer.Trainer_Graph(config: EasyDict)
Bases: Trainer
Trainer class for training the model with graphs.
Adapted from Jo, J. et al. (2022)
- class ccsd.src.trainer.Trainer_CC(config: EasyDict)
Bases: Trainer
Trainer class for training the model with combinatorial complexes.
- ccsd.src.trainer.get_trainer_from_config(config: EasyDict) → Trainer
Get the trainer from the configuration config.
- Parameters:
config (EasyDict) – configuration file
- Returns:
trainer to use for the experiment
- Return type:
Trainer
sampler.py: code for sampling from the model.
- class ccsd.src.sampler.Sampler(config: EasyDict)
Bases: ABC
Abstract class for Sampler objects.
- class ccsd.src.sampler.Sampler_Graph(config: EasyDict)
Bases: Sampler
Sampler for generic graph generation tasks
Adapted from Jo, J. et al. (2022)
- class ccsd.src.sampler.Sampler_CC(config: EasyDict)
Bases: Sampler
Sampler for generic combinatorial complex generation tasks
Adapted from Jo, J. et al. (2022)
- class ccsd.src.sampler.Sampler_mol_Graph(config: EasyDict)
Bases: Sampler
Sampler for molecule generation tasks
- class ccsd.src.sampler.Sampler_mol_CC(config: EasyDict)
Bases: Sampler
Sampler for molecule generation tasks with combinatorial complexes
- ccsd.src.sampler.get_sampler_from_config(config: EasyDict) → Sampler
Get the sampler from the configuration config.
- Parameters:
config (EasyDict) – configuration file
- Returns:
sampler to use for the experiment
- Return type:
Sampler
sde.py: contains the stochastic differential equation (SDE) classes VPSDE, VESDE, and subVPSDE, all of which inherit from the SDE class.
Adapted from Jo, J. et al. (2022)
- class ccsd.src.sde.SDE(N: int)
Bases: ABC
SDE abstract class. All functions are designed for a mini-batch of inputs.
- __init__(N: int) → None
Initialize an SDE.
- Parameters:
N (int) – number of discretization time steps.
- abstract property T: int
Return the final time of the SDE.
- Returns:
final time of the SDE.
- Return type:
int
- abstract sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Return the drift $f_t(x)$ and diffusion $G_t(x)$ of the SDE.
- Parameters:
x (torch.Tensor) – feature vector.
t (torch.Tensor) – time step (from 0 to self.T).
- Returns:
drift and diffusion.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- abstract marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Return the parameters of the marginal distribution $p_t(x)$ of the SDE.
- Parameters:
x (torch.Tensor) – feature vector.
t (torch.Tensor) – time step (from 0 to self.T).
- Returns:
mean and standard deviation of the perturbation kernel.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- abstract prior_sampling(shape: Sequence[int]) → Tensor
Generate one sample from the prior distribution, $p_T(x)$.
- Parameters:
shape (Sequence[int]) – shape of the sample.
- Returns:
sample from the prior distribution.
- Return type:
torch.Tensor
- abstract prior_logp(z: Tensor) → Tensor
Compute log-density of the prior distribution. Useful for computing the log-likelihood via probability flow ODE.
- Parameters:
z (torch.Tensor) – latent sample.
- Returns:
log probability density
- Return type:
torch.Tensor
- discretize(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Discretize the SDE in the form $x_{i+1} = x_i + f_i(x_i) + G_i z_i$. Useful for reverse diffusion sampling and probability flow sampling. The default implementation uses the Euler-Maruyama discretization.
- Parameters:
x (torch.Tensor) – torch tensor
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
drift and diffusion (f, G).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
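A minimal scalar sketch of the default Euler-Maruyama discretization, assuming a uniform step dt = T/N. The function name and scalar inputs are illustrative; the documented method works on mini-batched tensors and obtains the drift and diffusion from `self.sde`.

```python
import math
from typing import Callable, Tuple

# sde_fn(x, t) -> (drift, diffusion); scalars are used for readability.
SdeFn = Callable[[float, float], Tuple[float, float]]


def euler_maruyama_discretize(sde_fn: SdeFn, x: float, t: float, N: int,
                              T: float = 1.0) -> Tuple[float, float]:
    """Return (f, G) such that x_{i+1} = x_i + f + G * z with z ~ N(0, 1)."""
    dt = T / N
    drift, diffusion = sde_fn(x, t)
    f = drift * dt                 # deterministic increment over one step
    G = diffusion * math.sqrt(dt)  # Brownian increment has standard deviation sqrt(dt)
    return f, G
```

For example, with drift -x, constant diffusion 2, x = 1 and N = 100, this gives f = -0.01 and G = 0.2.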
- reverse(score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False) → SDE
Create the reverse-time SDE/ODE (RSDE).
- Parameters:
score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – time-dependent score-based model that takes x and t and returns the score.
probability_flow (bool, optional) – If True, create the reverse-time ODE used for probability flow sampling. Defaults to False.
is_cc (bool, optional) – If True, the reverse-time SDE/ODE also takes the rank-2 incidence matrix as input. Defaults to False.
- Returns:
reverse-time SDE/ODE.
- Return type:
SDE
- class ccsd.src.sde.VPSDE(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000)
Bases: SDE
Variance Preserving SDE (VPSDE).
- __init__(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000) → None
Construct a Variance Preserving SDE.
- Parameters:
beta_min (float) – value of beta(0)
beta_max (float) – value of beta(1)
N (int) – number of discretization steps
- property T: int
Return the final time of the SDE.
- Returns:
final time of the SDE.
- Return type:
int
- sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Returns the drift and diffusion for the SDE.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
drift and diffusion.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Returns the mean and std of the perturbation kernel.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
mean and std of the perturbation kernel.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
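A scalar sketch of the VP SDE's drift, diffusion, and perturbation kernel, assuming the standard linear schedule beta(t) = beta_min + t (beta_max - beta_min) on t in [0, 1], which matches the documented meaning of beta_min as beta(0) and beta_max as beta(1). The function names are hypothetical; the documented methods operate on mini-batched tensors.

```python
import math
from typing import Tuple


def vp_beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    # Linear noise schedule: beta(0) = beta_min, beta(1) = beta_max.
    return beta_min + t * (beta_max - beta_min)


def vp_sde(x: float, t: float, beta_min: float = 0.1,
           beta_max: float = 20.0) -> Tuple[float, float]:
    beta_t = vp_beta(t, beta_min, beta_max)
    drift = -0.5 * beta_t * x
    diffusion = math.sqrt(beta_t)
    return drift, diffusion


def vp_marginal_prob(x: float, t: float, beta_min: float = 0.1,
                     beta_max: float = 20.0) -> Tuple[float, float]:
    # log of exp(-1/2 * integral_0^t beta(s) ds)
    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean = math.exp(log_mean_coeff) * x
    std = math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))
    return mean, std
```

The name "variance preserving" shows up directly: for a unit-variance input, mean coefficient squared plus variance stays equal to 1 at every t.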
- prior_sampling(shape: Sequence[int]) → Tensor
Sample from the prior distribution. Here the prior is a standard Gaussian distribution.
- Parameters:
shape (Sequence[int]) – shape of the output tensor.
- Returns:
sample from the prior distribution.
- Return type:
torch.Tensor
- prior_sampling_sym(shape: Sequence[int]) → Tensor
Sample from the prior distribution in the symmetric case for a matrix. Here the prior is a standard Gaussian distribution.
- Parameters:
shape (Sequence[int]) – shape of the output tensor.
- Returns:
sample from the prior distribution.
- Return type:
torch.Tensor
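A sketch of symmetric prior sampling for a single n x n matrix, so that a sampled adjacency matrix is undirected. It assumes (not confirmed by this page) that only the strict upper triangle is drawn and mirrored, leaving the diagonal at zero; the function name and the explicit `rng` argument are illustrative.

```python
import random
from typing import List


def prior_sampling_sym(n: int, rng: random.Random) -> List[List[float]]:
    """Sample a symmetric n x n matrix with N(0, 1) off-diagonal entries.

    Assumption: the diagonal stays zero (strict upper triangle mirrored).
    """
    x = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):  # strict upper triangle only
            v = rng.gauss(0.0, 1.0)
            x[i][j] = v
            x[j][i] = v  # mirror to enforce symmetry
    return x
```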
- prior_logp(z: Tensor) → Tensor
Returns the log probability of the prior distribution.
- Parameters:
z (torch.Tensor) – latent sample.
- Returns:
log probability of the prior distribution.
- Return type:
torch.Tensor
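Since the VPSDE prior is a standard Gaussian, its log-density for one sample with N entries is $-\frac{N}{2}\log(2\pi) - \frac{1}{2}\sum_k z_k^2$. A scalar-list sketch (the function name is illustrative; the documented method computes this per sample of a mini-batch):

```python
import math
from typing import Sequence


def standard_normal_logp(z: Sequence[float]) -> float:
    """Log-density of z under N(0, I), summed over all entries of one sample."""
    n = len(z)
    return -0.5 * n * math.log(2.0 * math.pi) - 0.5 * sum(v * v for v in z)
```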
- discretize(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
DDPM discretization for the drift and diffusion of the SDE.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
discretized drift and diffusion (f, G).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
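A scalar sketch of a DDPM-style discretization of the VP SDE. It assumes the convention of the score-SDE reference code: N discrete betas evenly spaced between beta_min / N and beta_max / N, alpha_i = 1 - beta_i, and a continuous t in [0, T] mapped to index int(t (N - 1) / T). This is an assumption about the internals, and the function name is hypothetical.

```python
import math
from typing import Tuple


def vp_ddpm_discretize(x: float, t: float, beta_min: float = 0.1, beta_max: float = 20.0,
                       N: int = 1000, T: float = 1.0) -> Tuple[float, float]:
    """DDPM-style (f, G) for a VP SDE at continuous time t in [0, T]."""
    step = int(t * (N - 1) / T)  # map continuous t to a discrete index
    # k-th entry of linspace(beta_min / N, beta_max / N, N)
    beta = beta_min / N + step * (beta_max - beta_min) / (N * (N - 1))
    alpha = 1.0 - beta
    f = math.sqrt(alpha) * x - x  # noise-free part of x_{i+1} - x_i
    G = math.sqrt(beta)
    return f, G
```

With the defaults, the noise scale G grows from sqrt(1e-4) = 0.01 at t = 0 to sqrt(0.02) at t = 1.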
- transition(x: Tensor, t: Tensor, dt: float) → Tuple[Tensor, Tensor]
Returns the mean and std of the transition kernel.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
dt (float) – timestep size (negative here, which means going backward in time).
- Returns:
mean and std of the transition kernel.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- class ccsd.src.sde.VESDE(sigma_min: float = 0.01, sigma_max: float = 50.0, N: int = 1000)
Bases: SDE
Variance Exploding SDE (VESDE).
- __init__(sigma_min: float = 0.01, sigma_max: float = 50.0, N: int = 1000) → None
Initialize the Variance Exploding SDE.
- Parameters:
sigma_min (float) – smallest sigma.
sigma_max (float) – largest sigma.
N (int) – number of discretization steps
- property T: int
Return the final time of the SDE.
- Returns:
final time of the SDE.
- Return type:
int
- sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Returns the drift and diffusion of the SDE.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
drift and diffusion of the SDE.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Returns the mean and std of the marginal distribution at time t.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
mean and std of the marginal distribution.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
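A scalar sketch of the VE SDE, assuming the usual geometric schedule sigma(t) = sigma_min (sigma_max / sigma_min)^t on t in [0, 1]. With that schedule the SDE has zero drift, its diffusion is sigma(t) sqrt(2 log(sigma_max / sigma_min)), and the marginal keeps the mean at x with std sigma(t). The function names are hypothetical.

```python
import math
from typing import Tuple


def ve_sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 50.0) -> float:
    # Geometric schedule: sigma(0) = sigma_min, sigma(1) = sigma_max.
    return sigma_min * (sigma_max / sigma_min) ** t


def ve_sde(x: float, t: float, sigma_min: float = 0.01,
           sigma_max: float = 50.0) -> Tuple[float, float]:
    drift = 0.0  # the VE SDE has no drift term
    diffusion = ve_sigma(t, sigma_min, sigma_max) * math.sqrt(
        2.0 * (math.log(sigma_max) - math.log(sigma_min)))
    return drift, diffusion


def ve_marginal_prob(x: float, t: float, sigma_min: float = 0.01,
                     sigma_max: float = 50.0) -> Tuple[float, float]:
    # The mean is unchanged; only the noise level grows ("explodes") with t.
    return x, ve_sigma(t, sigma_min, sigma_max)
```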
- prior_sampling(shape: Sequence[int]) → Tensor
Returns a sample from the prior distribution. Here the prior is a standard Gaussian distribution.
- Parameters:
shape (Sequence[int]) – shape of the sample.
- Returns:
sample from the prior distribution.
- Return type:
torch.Tensor
- prior_sampling_sym(shape: Sequence[int]) → Tensor
Returns a sample from the prior distribution. Here the prior is a standard Gaussian distribution. Symmetric version of the prior sampling.
- Parameters:
shape (Sequence[int]) – shape of the sample.
- Returns:
sample from the prior distribution.
- Return type:
torch.Tensor
- prior_logp(z: Tensor) → Tensor
Returns the log probability of the prior distribution.
- Parameters:
z (torch.Tensor) – latent sample.
- Returns:
log probability of the prior distribution.
- Return type:
torch.Tensor
- discretize(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Returns the drift and diffusion of the discretized SDE, using the SMLD (NCSN) discretization.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
drift and diffusion of the discretized SDE.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
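A scalar sketch of the SMLD (NCSN) discretization: zero drift, and a noise scale G = sqrt(sigma_i^2 - sigma_{i-1}^2) taken from consecutive noise levels (with sigma_{-1} = 0). It assumes, as in the score-SDE reference code, N sigmas evenly spaced in log space between sigma_min and sigma_max and the index mapping int(t (N - 1) / T); the function name is hypothetical.

```python
import math
from typing import Tuple


def ve_smld_discretize(x: float, t: float, sigma_min: float = 0.01, sigma_max: float = 50.0,
                       N: int = 1000, T: float = 1.0) -> Tuple[float, float]:
    """SMLD/NCSN (f, G): zero drift, noise scale from consecutive sigma levels."""
    def discrete_sigma(k: int) -> float:
        # Sigmas evenly spaced in log space between sigma_min and sigma_max.
        log_s = math.log(sigma_min) + k * (math.log(sigma_max) - math.log(sigma_min)) / (N - 1)
        return math.exp(log_s)

    step = int(t * (N - 1) / T)
    sigma = discrete_sigma(step)
    adjacent_sigma = 0.0 if step == 0 else discrete_sigma(step - 1)
    f = 0.0
    G = math.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    return f, G
```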
- transition(x: Tensor, t: Tensor, dt: float) → Tuple[Tensor, Tensor]
Returns the mean and std of the transition kernel at time t for timestep dt. A negative dt means going backward in time.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
dt (float) – timestep
- Returns:
mean and std of the transition kernel.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- class ccsd.src.sde.subVPSDE(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000)
Bases: SDE
Class for the sub-VP SDE that excels at likelihoods.
- __init__(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000) → None
Construct the sub-VP SDE that excels at likelihoods.
- Parameters:
beta_min (float) – value of beta(0)
beta_max (float) – value of beta(1)
N (int) – number of discretization steps
- property T: int
Returns the final time of the SDE.
- Returns:
final time of the SDE.
- Return type:
int
- sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Returns the drift and diffusion of the SDE at time t.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
drift and diffusion of the SDE.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Returns the mean and std of the marginal distribution at time t.
- Parameters:
x (torch.Tensor) – torch tensor.
t (torch.Tensor) – torch float representing the time step (from 0 to self.T)
- Returns:
mean and std of the marginal distribution.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
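A scalar sketch of the sub-VP SDE, assuming the same linear beta schedule as the VP SDE. The drift is the VP drift, but the diffusion is discounted by 1 - exp(-integral of 2 beta), so the marginal variance stays below the VP one; that tighter variance is what helps likelihoods. The function names are hypothetical.

```python
import math
from typing import Tuple


def subvp_sde(x: float, t: float, beta_min: float = 0.1,
              beta_max: float = 20.0) -> Tuple[float, float]:
    beta_t = beta_min + t * (beta_max - beta_min)
    drift = -0.5 * beta_t * x  # same drift as the VP SDE
    # Discount factor 1 - exp(-2 * integral_0^t beta(s) ds).
    discount = 1.0 - math.exp(-2.0 * beta_min * t - (beta_max - beta_min) * t ** 2)
    diffusion = math.sqrt(beta_t * discount)
    return drift, diffusion


def subvp_marginal_prob(x: float, t: float, beta_min: float = 0.1,
                        beta_max: float = 20.0) -> Tuple[float, float]:
    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean = math.exp(log_mean_coeff) * x
    std = 1.0 - math.exp(2.0 * log_mean_coeff)  # no square root, unlike the VP std
    return mean, std
```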
- prior_sampling(shape: Sequence[int]) → Tensor
Returns a sample from the prior distribution. Here, the prior distribution is a standard Gaussian.
- Parameters:
shape (Sequence[int]) – shape of the sample.
- Returns:
sample from the prior distribution.
- Return type:
torch.Tensor
- prior_sampling_sym(shape: Sequence[int]) → Tensor
Returns a sample from the prior distribution. Here, the prior distribution is a standard Gaussian. Symmetric version of the prior sampling.
- Parameters:
shape (Sequence[int]) – shape of the sample.
- Returns:
sample from the prior distribution.
- Return type:
torch.Tensor
solver.py: contains the SDE solvers and the predictor and corrector algorithms. The correctors rely on score-based MCMC methods.
Adapted from Jo, J. et al. (2022) for combinatorial complexes and compatibility with higher-order domains.
- class ccsd.src.solver.Predictor(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Bases: ABC
Abstract class for a predictor algorithm.
- __init__(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None) → None
Initialize the Predictor.
- Parameters:
sde (SDE) – the SDE to solve
score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function
probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.
is_cc (bool, optional) – if True, get predictor for combinatorial complexes. Defaults to False.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- abstract update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Update function for the predictor class.
- Parameters:
x (torch.Tensor) – tensor
adj (torch.Tensor) – adjacency matrix. Optional.
rank2 (torch.Tensor) – rank-2 tensor. Optional.
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- class ccsd.src.solver.Corrector(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Bases: ABC
Abstract class for a corrector algorithm.
- __init__(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None) → None
Initialize the Corrector.
- Parameters:
sde (SDE) – the SDE to solve
score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function
snr (float) – signal-to-noise ratio
scale_eps (float) – scale of the noise
n_steps (int) – number of steps
is_cc (bool, optional) – if True, get corrector for combinatorial complexes. Defaults to False.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- abstract update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Update function for the corrector class.
- Parameters:
x (torch.Tensor) – tensor
adj (torch.Tensor) – adjacency matrix. Optional.
rank2 (torch.Tensor) – rank-2 tensor. Optional.
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- class ccsd.src.solver.EulerMaruyamaPredictor(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Bases: Predictor
Euler-Maruyama predictor.
- __init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None) → None
Initialize the Euler-Maruyama predictor.
- Parameters:
obj (str) – object to update, either “x”, “adj”, or “rank2”
sde (SDE) – the SDE to solve
score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function
probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.
is_cc (bool, optional) – if True, get Euler-Maruyama predictor for combinatorial complexes. Defaults to False.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Update function for the Euler-Maruyama predictor.
- update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the Euler-Maruyama predictor for graphs.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the Euler-Maruyama predictor for combinatorial complexes.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
rank2 (torch.Tensor) – rank-2 incidence matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
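A scalar sketch of one Euler-Maruyama predictor step on the reverse-time SDE. It uses the standard reverse drift f - G^2 * score, halved and made noise-free when probability_flow is True, and integrates backward with dt = -1/N. The names are hypothetical, and the Gaussian draw z is passed in explicitly (for example `random.gauss(0.0, 1.0)`) so the step is deterministic to test; the documented method also handles node features, adjacency, and rank-2 tensors with flags.

```python
import math
from typing import Callable, Tuple

SdeFn = Callable[[float, float], Tuple[float, float]]  # (x, t) -> (f, G)
ScoreFn = Callable[[float, float], float]               # (x, t) -> score


def reverse_sde(sde_fn: SdeFn, score_fn: ScoreFn, x: float, t: float,
                probability_flow: bool = False) -> Tuple[float, float]:
    """Drift and diffusion of the reverse-time SDE (or ODE if probability_flow)."""
    f, G = sde_fn(x, t)
    scale = 0.5 if probability_flow else 1.0
    drift = f - G ** 2 * score_fn(x, t) * scale
    diffusion = 0.0 if probability_flow else G
    return drift, diffusion


def euler_maruyama_step(sde_fn: SdeFn, score_fn: ScoreFn, x: float, t: float, N: int,
                        z: float, probability_flow: bool = False) -> Tuple[float, float]:
    """One predictor step backward in time; z is an N(0, 1) draw passed in explicitly."""
    dt = -1.0 / N  # negative: integrate from T toward 0
    drift, diffusion = reverse_sde(sde_fn, score_fn, x, t, probability_flow)
    x_mean = x + drift * dt
    x_new = x_mean + diffusion * math.sqrt(-dt) * z
    return x_new, x_mean
```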
- class ccsd.src.solver.ReverseDiffusionPredictor(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Bases: Predictor
Reverse diffusion predictor.
- __init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Initialize the Reverse Diffusion predictor.
- Parameters:
obj (str) – object to update, either “x”, “adj” or “rank2”
sde (SDE) – the SDE to solve
score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function
probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.
is_cc (bool, optional) – if True, get Reverse Diffusion predictor for combinatorial complexes. Defaults to False.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Update function for the Reverse Diffusion predictor.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the Reverse Diffusion predictor for graphs.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the Reverse Diffusion predictor for combinatorial complexes.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
rank2 (torch.Tensor) – rank-2 incidence matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
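A scalar sketch of one reverse-diffusion predictor step. It takes the forward discretization (f, G), converts it to the reverse one (rev_f = f - G^2 * score, halved and noise-free for probability flow), and steps backward with x_mean = x - rev_f. This follows the score-SDE reference formulation and is an assumption about this module's internals; the names are hypothetical and z is passed in explicitly.

```python
from typing import Callable, Tuple

DiscretizeFn = Callable[[float, float], Tuple[float, float]]  # forward (x, t) -> (f, G)
ScoreFn = Callable[[float, float], float]


def reverse_diffusion_step(discretize_fn: DiscretizeFn, score_fn: ScoreFn, x: float, t: float,
                           z: float, probability_flow: bool = False) -> Tuple[float, float]:
    """One reverse-diffusion predictor step; z is an N(0, 1) draw passed in explicitly."""
    f, G = discretize_fn(x, t)
    scale = 0.5 if probability_flow else 1.0
    rev_f = f - G ** 2 * score_fn(x, t) * scale  # discretized reverse drift
    rev_G = 0.0 if probability_flow else G
    x_mean = x - rev_f                           # step backward in time
    x_new = x_mean + rev_G * z
    return x_new, x_mean
```

Unlike the Euler-Maruyama predictor, this step reuses the SDE's own `discretize` (for example the DDPM or SMLD discretization) instead of a fixed dt.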
- class ccsd.src.solver.NoneCorrector(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Bases: Corrector
An empty corrector that does nothing.
- __init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Initialize the NoneCorrector (an empty corrector that does nothing).
- Parameters:
obj (str) – object to update, either “x” or “adj”
sde (SDE) – the SDE to solve. UNUSED HERE
score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function. UNUSED HERE
snr (float) – signal-to-noise ratio. UNUSED HERE
scale_eps (float) – scale of the noise. UNUSED HERE
n_steps (int) – number of steps to take. UNUSED HERE
is_cc (bool, optional) – if True, get NoneCorrector for combinatorial complexes. Defaults to False.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Update function for the NoneCorrector.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the NoneCorrector for graphs.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the NoneCorrector for combinatorial complexes.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
rank2 (torch.Tensor) – rank-2 incidence matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- class ccsd.src.solver.LangevinCorrector(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Bases: Corrector
Langevin corrector.
- __init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)
Initialize the Langevin corrector.
- Parameters:
obj (str) – object to update, either “x”, “adj” or “rank2”
sde (SDE) – the SDE to solve
score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function
snr (float) – signal-to-noise ratio
scale_eps (float) – scale of the noise
n_steps (int) – number of steps to take
is_cc (bool, optional) – if True, get Langevin corrector for combinatorial complexes. Defaults to False.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Update function for the Langevin corrector.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the Langevin corrector for graphs.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]
Update function for the Langevin corrector for combinatorial complexes.
- Parameters:
x (torch.Tensor) – node features
adj (torch.Tensor) – adjacency matrix
rank2 (torch.Tensor) – rank-2 incidence matrix
flags (torch.Tensor) – flags
t (torch.Tensor) – timestep
- Raises:
NotImplementedError – raised if the object to update is not recognized.
- Returns:
updated tensor and mean
- Return type:
Tuple[torch.Tensor, torch.Tensor]
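A sketch of the Langevin corrector on a single sample stored as a list. Each step picks a step size so the ratio of gradient norm to noise norm matches `snr`, then takes a noisy gradient-ascent step on log-density. Assumptions, not confirmed by this page: `scale_eps` multiplies the injected noise, and `alpha` is 1 for VE SDEs (for VP-type SDEs it would be the discrete alpha at t). The names are hypothetical, and the noise draws are passed in so the sketch is deterministic.

```python
import math
from typing import Callable, List, Tuple

ScoreFn = Callable[[List[float], float], List[float]]


def _norm(v: List[float]) -> float:
    return math.sqrt(sum(a * a for a in v))


def langevin_correct(score_fn: ScoreFn, x: List[float], t: float, snr: float, scale_eps: float,
                     n_steps: int, noises: List[List[float]],
                     alpha: float = 1.0) -> Tuple[List[float], List[float]]:
    """Run n_steps of Langevin dynamics; noises[i] is the N(0, I) draw for step i."""
    x_mean = list(x)
    for i in range(n_steps):
        grad = score_fn(x, t)
        noise = noises[i]
        # Step size chosen so the signal-to-noise ratio matches `snr`.
        step_size = (snr * _norm(noise) / _norm(grad)) ** 2 * 2.0 * alpha
        x_mean = [xi + step_size * gi for xi, gi in zip(x, grad)]
        x = [mi + math.sqrt(2.0 * step_size) * ni * scale_eps for mi, ni in zip(x_mean, noise)]
    return x, x_mean
```

In the batched version the two norms are typically averaged over the mini-batch rather than computed per sample.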
- ccsd.src.solver.get_predictor(predictor: str) → Predictor
Get the predictor function.
- Parameters:
predictor (str) – the predictor to use. Select from [Reverse, Euler].
- Raises:
NotImplementedError – raised if the predictor is not recognized.
- Returns:
the predictor function.
- Return type:
Predictor
- ccsd.src.solver.get_corrector(corrector: str) → Corrector
Get the corrector function.
- Parameters:
corrector (str) – the corrector to use. Select from [Langevin, None].
- Raises:
NotImplementedError – raised if the corrector is not recognized.
- Returns:
the corrector function.
- Return type:
Corrector
- ccsd.src.solver.get_pc_sampler(sde_x: SDE, sde_adj: SDE, shape_x: Sequence[int], shape_adj: Sequence[int], predictor: str = 'Euler', corrector: str = 'None', snr: float = 0.1, scale_eps: float = 1.0, n_steps: int = 1, probability_flow: bool = False, continuous: bool = False, denoise: bool = True, eps: float = 0.001, device: str = 'cuda', is_cc: bool = False, sde_rank2: SDE | None = None, shape_rank2: Sequence[int] | None = None, d_min: int | None = None, d_max: int | None = None) → Callable[[Module, Module, Tensor], Tuple[Tensor, Tensor, float, List[List[Tensor]]]] | Callable[[Module, Module, Module, Tensor], Tuple[Tensor, Tensor, Tensor, float, List[List[Tensor]]]]
Return a predictor-corrector (PC) sampler.
- Parameters:
sde_x (SDE) – SDE for the node features
sde_adj (SDE) – SDE for the adjacency matrix
shape_x (Sequence[int]) – shape of the node features
shape_adj (Sequence[int]) – shape of the adjacency matrix
predictor (str, optional) – predictor function. Select from [Euler, Reverse]. Defaults to “Euler”.
corrector (str, optional) – corrector function. Select from [Langevin, None]. Defaults to “None”.
snr (float, optional) – signal-to-noise ratio. Defaults to 0.1.
scale_eps (float, optional) – scale of the noise. Defaults to 1.0.
n_steps (int, optional) – number of steps to take. Defaults to 1.
probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.
continuous (bool, optional) – if True, use continuous-time SDEs for the score function. Defaults to False.
denoise (bool, optional) – if True, use denoising diffusion (returns the mean of the reverse SDE). Defaults to True.
eps (float, optional) – smallest time to which the reverse-time SDE is integrated, which avoids numerical issues near t = 0. Defaults to 1e-3.
device (str, optional) – device to use. Defaults to “cuda”.
is_cc (bool, optional) – if True, get PC sampler for combinatorial complexes. Defaults to False.
sde_rank2 (Optional[SDE], optional) – SDE for the higher-order features. Defaults to None.
shape_rank2 (Optional[Sequence[int]], optional) – shape of the higher-order features. Defaults to None.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- Returns:
PC sampler
- Return type:
Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]]]
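A scalar sketch of the predictor-corrector loop that such a sampler runs. It assumes the time grid goes from T down to `eps` in N evenly spaced steps, with the corrector applied before the predictor at each time, as in the score-SDE reference loop. The real sampler updates node features, adjacency (and rank-2) tensors jointly and masks them with flags; the names here are hypothetical.

```python
from typing import Callable, List, Tuple

# (x, t) -> (x_new, x_mean); stand-ins for a predictor's or corrector's update_fn.
StepFn = Callable[[float, float], Tuple[float, float]]


def pc_sample(x_init: float, predictor: StepFn, corrector: StepFn, N: int,
              T: float = 1.0, eps: float = 1e-3, denoise: bool = True) -> float:
    """Predictor-corrector loop from t = T down to t = eps (assumes N >= 2)."""
    timesteps: List[float] = [T + i * (eps - T) / (N - 1) for i in range(N)]
    x, x_mean = x_init, x_init
    for t in timesteps:
        x, x_mean = corrector(x, t)  # MCMC-style refinement at fixed t
        x, x_mean = predictor(x, t)  # one step of the reverse-time solver
    # With denoise=True, the noise-free mean of the final step is returned.
    return x_mean if denoise else x
```

Stopping at `eps` rather than 0 and returning the final mean are the two knobs exposed as the `eps` and `denoise` parameters above.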
- ccsd.src.solver.S4_solver(sde_x: SDE, sde_adj: SDE, shape_x: Sequence[int], shape_adj: Sequence[int], predictor: str = 'None', corrector: str = 'None', snr: float = 0.1, scale_eps: float = 1.0, n_steps: int = 1, probability_flow: bool = False, continuous: bool = False, denoise: bool = True, eps: float = 0.001, device: str = 'cuda', is_cc: bool = False, sde_rank2: SDE | None = None, shape_rank2: Sequence[int] | None = None, d_min: int | None = None, d_max: int | None = None) → Callable[[Module, Module, Tensor], Tuple[Tensor, Tensor, float, List[List[Tensor]]]] | Callable[[Module, Module, Module, Tensor], Tuple[Tensor, Tensor, Tensor, float, List[List[Tensor]]]]
Return an S4 sampler.
- Parameters:
sde_x (SDE) – SDE for the node features
sde_adj (SDE) – SDE for the adjacency matrix
shape_x (Sequence[int]) – shape of the node features
shape_adj (Sequence[int]) – shape of the adjacency matrix
predictor (str, optional) – predictor function. UNUSED HERE. Select from [Euler, Reverse]. Defaults to “None”.
corrector (str, optional) – corrector function. UNUSED HERE. Select from [Langevin, None]. Defaults to “None”.
snr (float, optional) – signal-to-noise ratio. Defaults to 0.1.
scale_eps (float, optional) – scale of the noise. Defaults to 1.0.
n_steps (int, optional) – number of steps to take. UNUSED HERE. Defaults to 1.
probability_flow (bool, optional) – if True, use probability flow sampling. UNUSED HERE. Defaults to False.
continuous (bool, optional) – if True, use continuous-time SDEs for the score function. Defaults to False.
denoise (bool, optional) – if True, use denoising diffusion (returns the mean of the reverse SDE). Defaults to True.
eps (float, optional) – smallest time to which the reverse-time SDE is integrated, which avoids numerical issues near t = 0. Defaults to 1e-3.
device (str, optional) – device to use. Defaults to “cuda”.
is_cc (bool, optional) – if True, get S4 sampler for combinatorial complexes. Defaults to False.
sde_rank2 (Optional[SDE], optional) – SDE for the higher-order features. Defaults to None.
shape_rank2 (Optional[Sequence[int]], optional) – shape of the higher-order features. Defaults to None.
d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
- Returns:
S4 sampler
- Return type:
Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]]]
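The `snr` and `scale_eps` parameters control a Langevin-type correction step. The sketch below uses the step-size rule of the Langevin corrector from Song et al. (2021), on which Jo, J. & al (2022) build. It is an assumption that this S4 solver uses the same rule, and `langevin_step` and `alpha` are illustrative names, not part of this module.

```python
import numpy as np


def langevin_step(x, score, snr=0.1, scale_eps=1.0, alpha=1.0, rng=None):
    """One Langevin correction step on a batch `x` of shape (B, ...).

    Hypothetical helper: the step size is chosen so that the ratio between
    the injected-noise norm and the score norm matches the target `snr`.
    """
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.standard_normal(x.shape)
    batch = x.shape[0]
    # Per-sample L2 norms, averaged over the batch
    grad_norm = np.linalg.norm(score.reshape(batch, -1), axis=-1).mean()
    noise_norm = np.linalg.norm(noise.reshape(batch, -1), axis=-1).mean()
    step_size = (snr * noise_norm / grad_norm) ** 2 * 2 * alpha
    x_mean = x + step_size * score
    # scale_eps rescales the injected noise (1.0 = standard Langevin dynamics)
    x_new = x_mean + np.sqrt(2 * step_size) * noise * scale_eps
    return x_new, x_mean, step_size


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5, 3))
score = -x  # score of a standard Gaussian, used here only as a stand-in
x_new, x_mean, step = langevin_step(x, score, snr=0.1, rng=rng)
```

Setting `scale_eps=0.0` removes the injected noise, so the step reduces to a deterministic move along the score.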
losses.py: Loss functions for training the SDEs.
Adapted from Jo, J. & al (2022) and left almost untouched, except for get_sde_loss_fn_cc.
- ccsd.src.losses.get_score_fn(sde: SDE, model: Module, train: bool = True, continuous: bool = True) Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] [source]#
Return the score function for the SDE and the model.
- Parameters:
sde (SDE) – Stochastic Differential Equation (SDE)
model (torch.nn.Module) – neural network model that predicts the score
train (bool, optional) – whether or not we train the model. Defaults to True.
continuous (bool, optional) – if the SDE is continuous (discrete NOT IMPLEMENTED HERE). Defaults to True.
- Raises:
NotImplementedError – raised if the SDE type is not implemented
- Returns:
score function
- Return type:
Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]
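For a variance-preserving (VP) SDE, score-based codebases in the style of Song et al. (2021) and Jo, J. & al (2022) typically recover the score from the network output by dividing by the marginal standard deviation and flipping the sign. For a variance-exploding (VE) SDE, they use the output directly. The NumPy sketch below illustrates that convention. `vp_marginal_std`, `make_score_fn` and the `beta_min`/`beta_max` defaults are assumptions for illustration, not this module's API.

```python
import numpy as np


def vp_marginal_std(t, beta_min=0.1, beta_max=1.0):
    """Std of p_t(x | x_0) for a linear-schedule VP SDE (Song et al., 2021)."""
    log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    return np.sqrt(1.0 - np.exp(2.0 * log_mean_coeff))


def make_score_fn(model, sde_type="VP", continuous=True):
    """Hypothetical analogue of get_score_fn: wraps `model` into a score function."""
    if not continuous:
        raise NotImplementedError("Discrete-time SDEs are not supported")
    if sde_type == "VP":
        def score_fn(x, adj, flags, t):
            out = model(x, adj, flags)
            std = vp_marginal_std(t)  # shape (B,)
            # Reshape std to (B, 1, 1, ...) so it broadcasts over the output
            std = std.reshape(-1, *([1] * (out.ndim - 1)))
            return -out / std
    elif sde_type == "VE":
        def score_fn(x, adj, flags, t):
            return model(x, adj, flags)
    else:
        raise NotImplementedError(f"SDE type {sde_type} not supported")
    return score_fn


model = lambda x, adj, flags: np.ones_like(adj)  # toy stand-in for a network
score_fn = make_score_fn(model, "VP")
adj = np.zeros((2, 3, 3))
t = np.array([0.5, 1.0])
score = score_fn(None, adj, None, t)
```

Reshaping the per-sample standard deviation, rather than hard-coding `[:, None, None]`, lets the same wrapper handle node features, adjacency matrices or higher-rank tensors.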
- ccsd.src.losses.get_score_fn_cc(sde: SDE, model: Module, train: bool = True, continuous: bool = True) Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor] [source]#
Return the score function for the SDE and the model, for combinatorial complexes.
- Parameters:
sde (SDE) – Stochastic Differential Equation (SDE)
model (torch.nn.Module) – neural network model that predicts the score
train (bool, optional) – whether or not we train the model. Defaults to True.
continuous (bool, optional) – if the SDE is continuous (discrete NOT IMPLEMENTED HERE). Defaults to True.
- Raises:
NotImplementedError – raised if the SDE type is not implemented
- Returns:
score function
- Return type:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]
- ccsd.src.losses.get_sde_loss_fn(sde_x: SDE, sde_adj: SDE, train: bool = True, reduce_mean: bool = False, continuous: bool = True, likelihood_weighting: bool = False, eps: float = 1e-05) Callable[[Module, Module, Tensor, Tensor], Tuple[Tensor, Tensor]] [source]#
Return the loss function for the SDEs with specific parameters.
- Parameters:
sde_x (SDE) – SDE for node features
sde_adj (SDE) – SDE for adjacency matrix
train (bool, optional) – whether or not we are training the model. Defaults to True.
reduce_mean (bool, optional) – if True, we reduce the loss by first taking the mean along the last axis. Defaults to False.
continuous (bool, optional) – if the SDE is continuous. Defaults to True.
likelihood_weighting (bool, optional) – if True, weight the loss with standard deviations. Defaults to False.
eps (float, optional) – parameter for sampling time. Defaults to 1e-5.
- Returns:
loss function
- Return type:
Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]
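The loss built here is, as far as the parameters indicate, a denoising score-matching objective in the style of Jo, J. & al (2022). It works in three steps:

1. Draw t uniformly in [eps, T].
2. Perturb the clean tensor with the SDE's marginal.
3. Penalise the gap between the predicted score (rescaled by the marginal std) and the injected noise.

The NumPy sketch below shows this for a single component under a linear-schedule VP SDE with T = 1. `dsm_loss`, `vp_marginal` and the exact reduction are assumptions for illustration, not this module's implementation.

```python
import numpy as np


def vp_marginal(x0, t, beta_min=0.1, beta_max=1.0):
    """Mean and std of p_t(x | x_0) for a linear-schedule VP SDE."""
    log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    shape = (-1,) + (1,) * (x0.ndim - 1)
    mean = np.exp(log_mean_coeff).reshape(shape) * x0
    std = np.sqrt(1.0 - np.exp(2.0 * log_mean_coeff))
    return mean, std


def dsm_loss(score_fn, x0, eps=1e-5, T=1.0, reduce_mean=False, rng=None):
    """Hypothetical denoising score-matching loss for one batched tensor x0."""
    rng = np.random.default_rng() if rng is None else rng
    batch = x0.shape[0]
    # eps keeps t away from 0, where the marginal std vanishes
    t = rng.uniform(size=batch) * (T - eps) + eps
    z = rng.standard_normal(x0.shape)
    mean, std = vp_marginal(x0, t)
    std_b = std.reshape((-1,) + (1,) * (x0.ndim - 1))
    perturbed = mean + std_b * z
    score = score_fn(perturbed, t)
    # With the exact conditional score, score * std == -z and the residual vanishes
    residual = np.square(score * std_b + z).reshape(batch, -1)
    if reduce_mean:
        per_sample = residual.mean(axis=-1)
    else:
        per_sample = 0.5 * residual.sum(axis=-1)
    return per_sample.mean()


rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 4, 4))


def oracle(xt, t):
    # Exact conditional score -(x_t - mean) / std**2, available only because x0 is known
    mean, std = vp_marginal(x0, t)
    return -(xt - mean) / std.reshape(-1, 1, 1) ** 2


loss_oracle = dsm_loss(oracle, x0, rng=rng)
loss_zero = dsm_loss(lambda xt, t: np.zeros_like(xt), x0, rng=rng)
```

The oracle reaches an essentially zero loss, while a network that always outputs zero pays roughly half the squared norm of the noise per sample.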
- ccsd.src.losses.get_sde_loss_fn_cc(sde_x: SDE, sde_adj: SDE, sde_rank2: SDE, d_min: int, d_max: int, train: bool = True, reduce_mean: bool = False, continuous: bool = True, likelihood_weighting: bool = False, eps: float = 1e-05) Callable[[Module, Module, Module, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]] [source]#
Return the loss function for the SDEs with specific parameters.
- Parameters:
sde_x (SDE) – SDE for node features
sde_adj (SDE) – SDE for adjacency matrix
sde_rank2 (SDE) – SDE for rank-2 incidence tensor
d_min (int) – minimum size of the rank-2 cells
d_max (int) – maximum size of the rank-2 cells
train (bool, optional) – whether or not we are training the model. Defaults to True.
reduce_mean (bool, optional) – if True, we reduce the loss by first taking the mean along the last axis. Defaults to False.
continuous (bool, optional) – if the SDE is continuous. Defaults to True.
likelihood_weighting (bool, optional) – if True, weight the loss with standard deviations. Defaults to False.
eps (float, optional) – parameter for sampling time. Defaults to 1e-5.
- Returns:
loss function
- Return type:
Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
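`d_min` and `d_max` bound the size of the rank-2 cells. This sketch assumes that every subset of `d_min` to `d_max` nodes is a candidate rank-2 cell. Under that assumption, the number of candidate cells for a complex with `n` nodes is a sum of binomial coefficients, which indicates how quickly the rank-2 incidence tensor can grow. `num_candidate_rank2_cells` is a hypothetical helper, not part of this module.

```python
from math import comb


def num_candidate_rank2_cells(n: int, d_min: int, d_max: int) -> int:
    """Count node subsets whose size lies in [d_min, d_max] (hypothetical helper)."""
    if not 0 <= d_min <= d_max:
        raise ValueError("expected 0 <= d_min <= d_max")
    # math.comb(n, k) returns 0 when k > n, so oversized cells contribute nothing
    return sum(comb(n, k) for k in range(d_min, d_max + 1))


print(num_candidate_rank2_cells(4, 3, 4))  # 5
print(num_candidate_rank2_cells(9, 3, 4))  # 210 (84 + 126)
```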