Diffusion

This module defines all the trainers, solvers, samplers, stochastic differential equations (SDEs) and losses.

trainer.py: code for training the model.

class ccsd.src.trainer.Trainer(config: EasyDict | None)

Bases: ABC

Abstract class for a Trainer.

__init__(config: EasyDict | None) → None

Initialize the trainer.

Parameters:

config (Optional[EasyDict]) – the config object to use; may be None.

abstract train(ts: str) → str

Load the models, the optimizers, etc., train the model and save the checkpoint.

Parameters:

ts (str) – checkpoint name (usually a timestamp)

Returns:

checkpoint name

Return type:

str
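The abstract-base-class contract above (subclasses must implement train(ts) and return the checkpoint name) can be sketched with Python's abc module. SketchTrainer and ToyTrainer are hypothetical stand-ins, not the ccsd classes, and a plain dict replaces EasyDict.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class SketchTrainer(ABC):
    """Hypothetical stand-in mirroring the documented Trainer contract."""

    def __init__(self, config: Optional[dict]) -> None:
        self.config = config
        self.learning_curves: Dict[str, List[float]] = {}

    @abstractmethod
    def train(self, ts: str) -> str:
        """Train the model and return the checkpoint name."""


class ToyTrainer(SketchTrainer):
    """Hypothetical concrete trainer; no real model is trained."""

    def train(self, ts: str) -> str:
        # A real trainer would build models/optimizers, loop over epochs and
        # save a checkpoint named after ts; here we only log a fake curve.
        self.learning_curves["train_loss"] = [1.0, 0.5, 0.25]
        return ts


trainer = ToyTrainer(config=None)
ckpt = trainer.train("example_ts")
```

Instantiating SketchTrainer directly raises TypeError, which is how the abc module enforces that every concrete trainer provides train.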

save_learning_curves(learning_curves: Dict[str, List[float]]) → None

Save the learning curves in a .npy file.

Parameters:

learning_curves (Dict[str, List[float]]) – the learning curves to save

plot_learning_curves(learning_curves: Dict[str, List[float]]) → None

Plot the learning curves.

Parameters:

learning_curves (Dict[str, List[float]]) – the learning curves to plot

class ccsd.src.trainer.Trainer_Graph(config: EasyDict)

Bases: Trainer

Trainer class for training the model with graphs.

Adapted from Jo, J. et al. (2022).

__init__(config: EasyDict) → None

Initialize the trainer with the different configs.

Parameters:

config (EasyDict) – the config object to use

train(ts: str) → str

Load the models, the optimizers, etc., train the model and save the checkpoint.

Parameters:

ts (str) – checkpoint name (usually a timestamp)

Returns:

checkpoint name

Return type:

str

class ccsd.src.trainer.Trainer_CC(config: EasyDict)

Bases: Trainer

Trainer class for training the model with combinatorial complexes.

__init__(config: EasyDict) → None

Initialize the trainer with the different configs.

Parameters:

config (EasyDict) – the config object to use

train(ts: str) → str

Load the models, the optimizers, etc., train the model and save the checkpoint.

Parameters:

ts (str) – checkpoint name (usually a timestamp)

Returns:

checkpoint name

Return type:

str

ccsd.src.trainer.get_trainer_from_config(config: EasyDict) → Trainer

Get the trainer specified by the configuration config.

Parameters:

config (EasyDict) – the configuration object

Returns:

trainer to use for the experiment

Return type:

Trainer

sampler.py: code for sampling from the model.

class ccsd.src.sampler.Sampler(config: EasyDict)

Bases: ABC

Abstract class for Sampler objects.

__init__(config: EasyDict) → None

Initialize the sampler.

Parameters:

config (EasyDict) – the config object to use

abstract sample() → None

Sample from the model: load the checkpoint and the models, generate samples, then evaluate, save and plot them.

class ccsd.src.sampler.Sampler_Graph(config: EasyDict)

Bases: Sampler

Sampler for generic graph generation tasks.

Adapted from Jo, J. et al. (2022).

__init__(config: EasyDict) → None

Initialize the sampler with the config and the device.

Parameters:

config (EasyDict) – the config object to use

sample() → None

Sample from the model: load the checkpoint and the models, generate samples, then evaluate, save and plot them.

class ccsd.src.sampler.Sampler_CC(config: EasyDict)

Bases: Sampler

Sampler for generic combinatorial complex generation tasks.

Adapted from Jo, J. et al. (2022).

__init__(config: EasyDict) → None

Initialize the sampler with the config and the device.

Parameters:

config (EasyDict) – the config object to use

sample() → None

Sample from the model: load the checkpoint and the models, generate samples, then evaluate, save and plot them.

class ccsd.src.sampler.Sampler_mol_Graph(config: EasyDict)

Bases: Sampler

Sampler for molecule generation tasks.

__init__(config: EasyDict) → None

Initialize the sampler with the config and the device.

Parameters:

config (EasyDict) – the config object to use

sample() → None

Sample from the model: load the checkpoint and the models, generate samples, then evaluate and save them.

class ccsd.src.sampler.Sampler_mol_CC(config: EasyDict)

Bases: Sampler

Sampler for molecule generation tasks with combinatorial complexes.

__init__(config: EasyDict) → None

Initialize the sampler with the config and the device.

Parameters:

config (EasyDict) – the config object to use

sample() → None

Sample from the model: load the checkpoint and the models, generate samples, then evaluate and save them.

ccsd.src.sampler.get_sampler_from_config(config: EasyDict) → Sampler

Get the sampler specified by the configuration config.

Parameters:

config (EasyDict) – the configuration object

Returns:

sampler to use for the experiment

Return type:

Sampler

sde.py: contains the stochastic differential equation (SDE) classes VPSDE, VESDE and subVPSDE, all of which inherit from the abstract SDE class.

Adapted from Jo, J. et al. (2022).

class ccsd.src.sde.SDE(N: int)

Bases: ABC

Abstract SDE class. All methods operate on a mini-batch of inputs.

__init__(N: int) → None

Initialize an SDE.

Parameters:

N (int) – number of discretization time steps.

abstract property T: int

Return the final time of the SDE.

Returns:

final time of the SDE.

Return type:

int

abstract sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Compute the drift and diffusion functions of the SDE, $f_t(x)$ and $G_t(x)$.

Parameters:
  • x (torch.Tensor) – feature vector.

  • t (torch.Tensor) – time step (from 0 to self.T).

Returns:

drift and diffusion.

Return type:

Tuple[torch.Tensor, torch.Tensor]

abstract marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Compute the parameters of the perturbation kernel $p_{0t}(x(t) \mid x(0))$, i.e. the marginal distribution of the SDE at time t conditioned on x(0).

Parameters:
  • x (torch.Tensor) – feature vector.

  • t (torch.Tensor) – time step (from 0 to self.T).

Returns:

mean and standard deviation of the perturbation kernel.

Return type:

Tuple[torch.Tensor, torch.Tensor]

abstract prior_sampling(shape: Sequence[int]) → Tensor

Generate one sample from the prior distribution, $p_T(x)$.

Parameters:

shape (Sequence[int]) – shape of the sample.

Returns:

sample from the prior distribution.

Return type:

torch.Tensor

abstract prior_logp(z: Tensor) → Tensor

Compute the log-density of the prior distribution. Useful for computing the log-likelihood via the probability flow ODE.

Parameters:

z (torch.Tensor) – latent sample.

Returns:

log probability density

Return type:

torch.Tensor

discretize(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Discretize the SDE in the form $x_{i+1} = x_i + f_i(x_i) + G_i z_i$. Useful for reverse diffusion sampling and probability flow sampling. Defaults to the Euler-Maruyama discretization.

Parameters:
  • x (torch.Tensor) – torch tensor

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

drift and diffusion (f, G).

Return type:

Tuple[torch.Tensor, torch.Tensor]
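A minimal sketch of the default Euler-Maruyama discretization, written on Python floats instead of tensors and assuming a uniform grid on [0, 1] (T = 1, dt = 1/N). euler_maruyama_discretize and its sde_fn argument are hypothetical helpers, not part of ccsd.

```python
import math
from typing import Callable, Tuple

# sde_fn(x, t) -> (drift, diffusion); floats stand in for tensors here.
SDEFn = Callable[[float, float], Tuple[float, float]]


def euler_maruyama_discretize(sde_fn: SDEFn, x: float, t: float, N: int) -> Tuple[float, float]:
    """Return (f, G) such that x_{i+1} = x_i + f + G * z with z ~ N(0, 1)."""
    dt = 1.0 / N                   # uniform step, assuming T = 1
    drift, diffusion = sde_fn(x, t)
    f = drift * dt                 # deterministic increment
    G = diffusion * math.sqrt(dt)  # scale of the Brownian increment
    return f, G


# Illustrative SDE dx = -x dt + dw, evaluated at x = 2, t = 0.5 with N = 100.
f, G = euler_maruyama_discretize(lambda x, t: (-x, 1.0), x=2.0, t=0.5, N=100)
```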

reverse(score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False) → SDE

Create the reverse-time SDE/ODE (RSDE).

Parameters:
  • score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – time-dependent score-based model that takes x and t and returns the score.

  • probability_flow (bool, optional) – If True, create the reverse-time ODE used for probability flow sampling. Defaults to False.

  • is_cc (bool, optional) – If True, the reverse-time SDE/ODE takes the rank-2 incidence matrix as an additional input. Defaults to False.

Returns:

reverse-time SDE/ODE.

Return type:

SDE
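The reverse-time construction follows the standard score-SDE recipe: the reverse drift is $f - G^2 \nabla_x \log p_t(x)$, the score term is halved and the diffusion set to zero for the probability flow ODE. The sketch below assumes scalar states and a hypothetical score function of (x, t) only; the ccsd score functions also take adjacency, flags and, for combinatorial complexes, rank-2 inputs.

```python
from typing import Callable, Tuple

SDEFn = Callable[[float, float], Tuple[float, float]]
ScoreFn = Callable[[float, float], float]


def reverse_sde(sde_fn: SDEFn, score_fn: ScoreFn, probability_flow: bool = False) -> SDEFn:
    """Build the reverse-time drift/diffusion from a forward SDE and a score."""

    def rsde(x: float, t: float) -> Tuple[float, float]:
        drift, diffusion = sde_fn(x, t)
        # The probability flow ODE uses half of the score correction and no noise.
        coeff = 0.5 if probability_flow else 1.0
        rev_drift = drift - diffusion ** 2 * score_fn(x, t) * coeff
        rev_diffusion = 0.0 if probability_flow else diffusion
        return rev_drift, rev_diffusion

    return rsde
```

For example, with a driftless forward SDE of unit diffusion and the score of N(0, 1), which is -x, the reverse drift at x = 2 is 2 for the SDE and 1 for the ODE.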

class ccsd.src.sde.VPSDE(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000)

Bases: SDE

Variance Preserving SDE (VPSDE).

__init__(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000) → None

Construct a Variance Preserving SDE.

Parameters:
  • beta_min (float) – value of beta(0)

  • beta_max (float) – value of beta(1)

  • N (int) – number of discretization steps

property T: int

Return the final time of the SDE.

Returns:

final time of the SDE.

Return type:

int

sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Returns the drift and diffusion for the SDE.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

drift and diffusion.

Return type:

Tuple[torch.Tensor, torch.Tensor]
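Assuming the usual linear schedule $\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})$ on [0, 1] (consistent with beta_min = beta(0) and beta_max = beta(1) above), the VP SDE has drift $-\tfrac{1}{2}\beta(t)x$ and diffusion $\sqrt{\beta(t)}$. A scalar sketch with hypothetical helper names, not the ccsd implementation:

```python
import math
from typing import Tuple


def vp_beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Linear noise schedule beta(t) on [0, 1]."""
    return beta_min + t * (beta_max - beta_min)


def vp_sde(x: float, t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> Tuple[float, float]:
    """Drift -beta(t) * x / 2 and diffusion sqrt(beta(t)) of the VP SDE."""
    beta_t = vp_beta(t, beta_min, beta_max)
    return -0.5 * beta_t * x, math.sqrt(beta_t)
```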

marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Returns the mean and std of the perturbation kernel.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

mean and std of the perturbation kernel.

Return type:

Tuple[torch.Tensor, torch.Tensor]
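Under the same linear-schedule assumption, the VP perturbation kernel is Gaussian with mean $e^{-\frac{1}{2}\int_0^t \beta(s)\,ds}\, x(0)$ and std $\sqrt{1 - e^{-\int_0^t \beta(s)\,ds}}$. The hypothetical helpers below compute these on floats and draw x(t) = mean + std * z, as a denoising score-matching loss typically would.

```python
import math
import random
from typing import Tuple


def vp_marginal_prob(x0: float, t: float, beta_min: float = 0.1,
                     beta_max: float = 20.0) -> Tuple[float, float]:
    """Mean and std of p_{0t}(x(t) | x(0)) for the VP SDE with a linear schedule."""
    # log of the mean coefficient: -1/2 * integral_0^t beta(s) ds
    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean = math.exp(log_mean_coeff) * x0
    std = math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))
    return mean, std


def perturb(x0: float, t: float, rng: random.Random) -> float:
    """Draw x(t) from the perturbation kernel."""
    mean, std = vp_marginal_prob(x0, t)
    return mean + std * rng.gauss(0.0, 1.0)
```

At t = 0 the kernel is a point mass at x(0); at t = 1 the mean has almost vanished and the std is close to 1, matching the standard Gaussian prior.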

prior_sampling(shape: Sequence[int]) → Tensor

Sample from the prior distribution. Here the prior is a standard Gaussian distribution.

Parameters:

shape (Sequence[int]) – shape of the output tensor.

Returns:

sample from the prior distribution.

Return type:

torch.Tensor

prior_sampling_sym(shape: Sequence[int]) → Tensor

Sample a symmetric matrix from the prior distribution. Here the prior is a standard Gaussian distribution.

Parameters:

shape (Sequence[int]) – shape of the output tensor.

Returns:

sample from the prior distribution.

Return type:

torch.Tensor
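One common way to draw a symmetric Gaussian matrix is to sample the strict upper triangle and mirror it. The sketch below assumes a zero diagonal (no self-loops) and uses nested lists instead of tensors; whether ccsd zeroes the diagonal is an assumption, and prior_sampling_sym here is a hypothetical stand-alone function.

```python
import random
from typing import List


def prior_sampling_sym(n: int, rng: random.Random) -> List[List[float]]:
    """Symmetric n x n matrix with N(0, 1) off-diagonal entries and a zero diagonal."""
    mat = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):  # strict upper triangle only
            v = rng.gauss(0.0, 1.0)
            mat[i][j] = v
            mat[j][i] = v          # mirror so that A == A^T
    return mat
```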

prior_logp(z: Tensor) → Tensor

Returns the log probability of the prior distribution.

Parameters:

z (torch.Tensor) – latent sample.

Returns:

log probability of the prior distribution.

Return type:

torch.Tensor
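For a standard Gaussian prior, the log-density of one sample z with n entries is $-\tfrac{n}{2}\log(2\pi) - \tfrac{1}{2}\lVert z\rVert^2$. A float-based sketch with a hypothetical helper name (the library version works per batch element on tensors):

```python
import math
from typing import Sequence


def standard_normal_logp(z: Sequence[float]) -> float:
    """Log-density of z under N(0, I), summed over all entries of one sample."""
    n = len(z)
    return -0.5 * n * math.log(2.0 * math.pi) - 0.5 * sum(v * v for v in z)
```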

discretize(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

DDPM discretization for the drift and diffusion of the SDE.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

discretized drift and diffusion (f, G).

Return type:

Tuple[torch.Tensor, torch.Tensor]
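A scalar sketch of the DDPM discretization, assuming (as in the reference score-SDE code) N discrete betas linearly spaced from beta_min/N to beta_max/N and t mapped to the index floor(t(N - 1)/T). The DDPM step is $x_{i+1} = \sqrt{1-\beta_i}\,x_i + \sqrt{\beta_i}\,z_i$, so $f = (\sqrt{1-\beta_i} - 1)x_i$ and $G = \sqrt{\beta_i}$. The helper name is hypothetical.

```python
import math
from typing import Tuple


def ddpm_discretize(x: float, t: float, beta_min: float = 0.1, beta_max: float = 20.0,
                    N: int = 1000, T: float = 1.0) -> Tuple[float, float]:
    """Return (f, G) so that x_{i+1} = x_i + f + G * z follows the DDPM chain (N > 1)."""
    # Map continuous time t in [0, T] to an integer index in [0, N - 1].
    timestep = int(t * (N - 1) / T)
    # Discrete betas are linearly spaced from beta_min / N to beta_max / N.
    beta = beta_min / N + timestep * (beta_max - beta_min) / (N * (N - 1))
    alpha = 1.0 - beta
    f = math.sqrt(alpha) * x - x  # deterministic part of x_{i+1} - x_i
    G = math.sqrt(beta)           # noise scale of the step
    return f, G
```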

transition(x: Tensor, t: Tensor, dt: float) → Tuple[Tensor, Tensor]

Returns the mean and std of the transition kernel.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

  • dt (float) – time step (negative here, since sampling goes backward in time).

Returns:

mean and std of the transition kernel.

Return type:

Tuple[torch.Tensor, torch.Tensor]

class ccsd.src.sde.VESDE(sigma_min: float = 0.01, sigma_max: float = 50.0, N: int = 1000)

Bases: SDE

Variance Exploding SDE (VESDE).

__init__(sigma_min: float = 0.01, sigma_max: float = 50.0, N: int = 1000) → None

Initialize the Variance Exploding SDE.

Parameters:
  • sigma_min (float) – smallest sigma.

  • sigma_max (float) – largest sigma.

  • N (int) – number of discretization steps

property T: int

Return the final time of the SDE.

Returns:

final time of the SDE.

Return type:

int

sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Returns the drift and diffusion of the SDE.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

drift and diffusion of the SDE.

Return type:

Tuple[torch.Tensor, torch.Tensor]

marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Returns the mean and std of the marginal distribution at time t.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

mean and std of the marginal distribution.

Return type:

Tuple[torch.Tensor, torch.Tensor]
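For the VE SDE with the geometric schedule $\sigma(t) = \sigma_{\min}(\sigma_{\max}/\sigma_{\min})^t$ (an assumption consistent with sigma_min and sigma_max above), the drift is zero, the diffusion is $\sigma(t)\sqrt{2\log(\sigma_{\max}/\sigma_{\min})}$, and the marginal keeps the mean x(0) with std $\sigma(t)$. A scalar sketch with hypothetical helper names:

```python
import math
from typing import Tuple


def ve_sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 50.0) -> float:
    """Geometric noise schedule sigma(t) = sigma_min * (sigma_max / sigma_min) ** t."""
    return sigma_min * (sigma_max / sigma_min) ** t


def ve_sde(x: float, t: float, sigma_min: float = 0.01,
           sigma_max: float = 50.0) -> Tuple[float, float]:
    """Zero drift; diffusion sigma(t) * sqrt(2 * log(sigma_max / sigma_min))."""
    sigma = ve_sigma(t, sigma_min, sigma_max)
    return 0.0, sigma * math.sqrt(2.0 * math.log(sigma_max / sigma_min))


def ve_marginal_prob(x0: float, t: float, sigma_min: float = 0.01,
                     sigma_max: float = 50.0) -> Tuple[float, float]:
    """The mean stays x(0); the std is sigma(t)."""
    return x0, ve_sigma(t, sigma_min, sigma_max)
```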

prior_sampling(shape: Sequence[int]) → Tensor

Returns a sample from the prior distribution. Here the prior is a standard Gaussian distribution.

Parameters:

shape (Sequence[int]) – shape of the sample.

Returns:

sample from the prior distribution.

Return type:

torch.Tensor

prior_sampling_sym(shape: Sequence[int]) → Tensor

Returns a symmetric matrix sampled from the prior distribution. Here the prior is a standard Gaussian distribution.

Parameters:

shape (Sequence[int]) – shape of the sample.

Returns:

sample from the prior distribution.

Return type:

torch.Tensor

prior_logp(z: Tensor) → Tensor

Returns the log probability of the prior distribution.

Parameters:

z (torch.Tensor) – latent sample.

Returns:

log probability of the prior distribution.

Return type:

torch.Tensor

discretize(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Returns the drift and diffusion of the discretized SDE, using the SMLD (NCSN) discretization.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

drift and diffusion of the discretized SDE.

Return type:

Tuple[torch.Tensor, torch.Tensor]
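The SMLD (NCSN) discretization uses N geometrically spaced noise levels: each step has no drift and adds noise with variance $\sigma_i^2 - \sigma_{i-1}^2$ (with $\sigma_{-1} = 0$). The sketch assumes the same time-to-index mapping as the DDPM case and works on floats; smld_discretize is a hypothetical name.

```python
import math
from typing import Tuple


def smld_discretize(x: float, t: float, sigma_min: float = 0.01, sigma_max: float = 50.0,
                    N: int = 1000, T: float = 1.0) -> Tuple[float, float]:
    """Return (f, G) for the SMLD/NCSN chain x_{i+1} = x_i + G * z (zero drift)."""
    timestep = int(t * (N - 1) / T)

    def discrete_sigma(i: int) -> float:
        # N noise levels spaced geometrically from sigma_min to sigma_max.
        log_min, log_max = math.log(sigma_min), math.log(sigma_max)
        return math.exp(log_min + i * (log_max - log_min) / (N - 1))

    sigma = discrete_sigma(timestep)
    adjacent_sigma = 0.0 if timestep == 0 else discrete_sigma(timestep - 1)
    f = 0.0 * x  # zero drift
    G = math.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    return f, G
```

Summed over all steps, the noise variances telescope to $\sigma_{\max}^2$, which is why the chain ends at the wide prior.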

transition(x: Tensor, t: Tensor, dt: float) → Tuple[Tensor, Tensor]

Returns the mean and std of the transition kernel at time t for the timestep dt (a negative dt means going backward in time).

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

  • dt (float) – timestep

Returns:

mean and std of the transition kernel.

Return type:

Tuple[torch.Tensor, torch.Tensor]

class ccsd.src.sde.subVPSDE(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000)

Bases: SDE

Class for the sub-VP SDE that excels at likelihoods.

__init__(beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000) → None

Construct the sub-VP SDE that excels at likelihoods.

Parameters:
  • beta_min (float) – value of beta(0)

  • beta_max (float) – value of beta(1)

  • N (int) – number of discretization steps

property T: int

Returns the final time of the SDE.

Returns:

final time of the SDE.

Return type:

int

sde(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Returns the drift and diffusion of the SDE at time t.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

drift and diffusion of the SDE.

Return type:

Tuple[torch.Tensor, torch.Tensor]

marginal_prob(x: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Returns the mean and std of the marginal distribution at time t.

Parameters:
  • x (torch.Tensor) – torch tensor.

  • t (torch.Tensor) – torch float representing the time step (from 0 to self.T)

Returns:

mean and std of the marginal distribution.

Return type:

Tuple[torch.Tensor, torch.Tensor]
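The sub-VP SDE keeps the VP drift but damps the diffusion by $1 - e^{-2\int_0^t \beta(s)\,ds}$, which makes its marginal std $1 - e^{-\int_0^t \beta(s)\,ds}$ (no square root), never larger than the VP std. A scalar sketch under the linear-schedule assumption, with hypothetical helper names:

```python
import math
from typing import Tuple


def subvp_sde(x: float, t: float, beta_min: float = 0.1,
              beta_max: float = 20.0) -> Tuple[float, float]:
    """Same drift as the VP SDE; diffusion damped by 1 - exp(-2 * integral of beta)."""
    beta_t = beta_min + t * (beta_max - beta_min)
    discount = 1.0 - math.exp(-2.0 * beta_min * t - (beta_max - beta_min) * t ** 2)
    return -0.5 * beta_t * x, math.sqrt(beta_t * discount)


def subvp_marginal_prob(x0: float, t: float, beta_min: float = 0.1,
                        beta_max: float = 20.0) -> Tuple[float, float]:
    """Mean as in the VP SDE; std = 1 - exp(-integral_0^t beta(s) ds)."""
    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean = math.exp(log_mean_coeff) * x0
    std = 1.0 - math.exp(2.0 * log_mean_coeff)
    return mean, std
```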

prior_sampling(shape: Sequence[int]) → Tensor

Returns a sample from the prior distribution. Here, the prior distribution is a standard Gaussian.

Parameters:

shape (Sequence[int]) – shape of the sample.

Returns:

sample from the prior distribution.

Return type:

torch.Tensor

prior_sampling_sym(shape: Sequence[int]) → Tensor

Returns a symmetric matrix sampled from the prior distribution. Here, the prior distribution is a standard Gaussian.

Parameters:

shape (Sequence[int]) – shape of the sample.

Returns:

sample from the prior distribution.

Return type:

torch.Tensor

prior_logp(z: Tensor) → Tensor

Returns the log probability of the prior distribution.

Parameters:

z (torch.Tensor) – latent sample.

Returns:

log probability of the prior distribution.

Return type:

torch.Tensor

solver.py: contains the SDE solvers and the predictor and corrector algorithms. The correctors are score-based MCMC methods.

Adapted from Jo, J. et al. (2022) for combinatorial complexes and compatibility with higher-order domains.

class ccsd.src.solver.Predictor(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Bases: ABC

Abstract class for a predictor algorithm.

__init__(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None) → None

Initialize the Predictor.

Parameters:
  • sde (SDE) – the SDE to solve

  • score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function

  • probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.

  • is_cc (bool, optional) – if True, get predictor for combinatorial complexes. Defaults to False.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

abstract update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Update function for the predictor class.

Parameters:
  • x (torch.Tensor) – tensor

  • adj (torch.Tensor) – adjacency matrix. Optional.

  • rank2 (torch.Tensor) – rank-2 incidence matrix. Optional.

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

class ccsd.src.solver.Corrector(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Bases: ABC

Abstract class for a corrector algorithm.

__init__(sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None) → None

Initialize the Corrector.

Parameters:
  • sde (SDE) – the SDE to solve

  • score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function

  • snr (float) – signal-to-noise ratio

  • scale_eps (float) – scale of the noise

  • n_steps (int) – number of steps

  • is_cc (bool, optional) – if True, get corrector for combinatorial complexes. Defaults to False.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

abstract update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Update function for the corrector class.

Parameters:
  • x (torch.Tensor) – tensor

  • adj (torch.Tensor) – adjacency matrix. Optional.

  • rank2 (torch.Tensor) – rank-2 incidence matrix. Optional.

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

class ccsd.src.solver.EulerMaruyamaPredictor(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Bases: Predictor

Euler-Maruyama predictor.

__init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None) → None

Initialize the Euler-Maruyama predictor.

Parameters:
  • obj (str) – object to update, either “x”, “adj”, or “rank2”

  • sde (SDE) – the SDE to solve

  • score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function

  • probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.

  • is_cc (bool, optional) – if True, get Euler-Maruyama predictor for combinatorial complexes. Defaults to False.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Update function for the Euler-Maruyama predictor.

update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the Euler-Maruyama predictor for graphs.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the Euler-Maruyama predictor for combinatorial complexes.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • rank2 (torch.Tensor) – rank-2 incidence matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]
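One reverse-time Euler-Maruyama step uses a negative step dt = -1/N: $x_{\text{mean}} = x + f_{\text{rev}}\,dt$, then $x = x_{\text{mean}} + G\sqrt{-dt}\,z$. A float sketch with a hypothetical rsde callable returning the reverse drift and diffusion; the ccsd predictor works on node, adjacency and rank-2 tensors, which this sketch omits.

```python
import math
import random
from typing import Callable, Tuple

RSDEFn = Callable[[float, float], Tuple[float, float]]


def euler_maruyama_step(rsde: RSDEFn, x: float, t: float, N: int,
                        rng: random.Random) -> Tuple[float, float]:
    """One reverse-time Euler-Maruyama step; returns (x, x_mean)."""
    dt = -1.0 / N                      # negative: integrate from T toward 0
    drift, diffusion = rsde(x, t)
    x_mean = x + drift * dt            # deterministic part of the step
    z = rng.gauss(0.0, 1.0)
    x_new = x_mean + diffusion * math.sqrt(-dt) * z
    return x_new, x_mean
```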

class ccsd.src.solver.ReverseDiffusionPredictor(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Bases: Predictor

Reverse diffusion predictor.

__init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor], probability_flow: bool = False, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Initialize the Reverse Diffusion predictor.

Parameters:
  • obj (str) – object to update, either “x”, “adj” or “rank2”

  • sde (SDE) – the SDE to solve

  • score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function

  • probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.

  • is_cc (bool, optional) – if True, get Reverse Diffusion predictor for combinatorial complexes. Defaults to False.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Update function for the Reverse Diffusion predictor.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the Reverse Diffusion predictor for graphs.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the Reverse Diffusion predictor for combinatorial complexes.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • rank2 (torch.Tensor) – rank-2 incidence matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]
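The reverse diffusion (ancestral) predictor uses the discretized reverse SDE instead: with (f, G) from the reverse-time discretize, it sets $x_{\text{mean}} = x - f$ and $x = x_{\text{mean}} + G z$. A float sketch with a hypothetical rsde_discretize callable; this mirrors the reference score-SDE predictor and is an assumption about the ccsd internals.

```python
import random
from typing import Callable, Tuple

# rsde_discretize(x, t) -> (f, G) of the reverse-time chain (hypothetical stand-in).
DiscretizeFn = Callable[[float, float], Tuple[float, float]]


def reverse_diffusion_step(rsde_discretize: DiscretizeFn, x: float, t: float,
                           rng: random.Random) -> Tuple[float, float]:
    """One ancestral-sampling step; returns (x, x_mean)."""
    f, G = rsde_discretize(x, t)
    x_mean = x - f                     # undo one discretized forward step
    x_new = x_mean + G * rng.gauss(0.0, 1.0)
    return x_new, x_mean
```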

class ccsd.src.solver.NoneCorrector(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Bases: Corrector

An empty corrector that does nothing.

__init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Initialize the NoneCorrector (an empty corrector that does nothing).

Parameters:
  • obj (str) – object to update, either “x”, “adj” or “rank2”

  • sde (SDE) – the SDE to solve. UNUSED HERE

  • score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function. UNUSED HERE

  • snr (float) – signal-to-noise ratio. UNUSED HERE

  • scale_eps (float) – scale of the noise. UNUSED HERE

  • n_steps (int) – number of steps to take. UNUSED HERE

  • is_cc (bool, optional) – if True, get NoneCorrector for combinatorial complexes. Defaults to False.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Update function for the NoneCorrector.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the NoneCorrector for graphs.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the NoneCorrector for combinatorial complexes.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • rank2 (torch.Tensor) – rank-2 incidence matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

class ccsd.src.solver.LangevinCorrector(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Bases: Corrector

Langevin corrector.

__init__(obj: str, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor] | Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor], snr: float, scale_eps: float, n_steps: int, is_cc: bool = False, d_min: int | None = None, d_max: int | None = None)

Initialize the Langevin corrector.

Parameters:
  • obj (str) – object to update, either “x”, “adj” or “rank2”

  • sde (SDE) – the SDE to solve

  • score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]) – the score function

  • snr (float) – signal-to-noise ratio

  • scale_eps (float) – scale of the noise

  • n_steps (int) – number of steps to take

  • is_cc (bool, optional) – if True, get Langevin corrector for combinatorial complexes. Defaults to False.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

update_fn(*args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Update function for the Langevin corrector.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

update_fn_graph(x: Tensor, adj: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the Langevin corrector for graphs.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]

update_fn_cc(x: Tensor, adj: Tensor, rank2: Tensor, flags: Tensor, t: Tensor) → Tuple[Tensor, Tensor]

Update function for the Langevin corrector for combinatorial complexes.

Parameters:
  • x (torch.Tensor) – node features

  • adj (torch.Tensor) – adjacency matrix

  • rank2 (torch.Tensor) – rank-2 incidence matrix

  • flags (torch.Tensor) – flags

  • t (torch.Tensor) – timestep

Raises:

NotImplementedError – raise an error if the object to update is not recognized.

Returns:

updated tensor and mean

Return type:

Tuple[torch.Tensor, torch.Tensor]
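The Langevin corrector runs n_steps of Langevin MCMC with a step size derived from the target signal-to-noise ratio, as in the reference score-SDE code. In this sketch one sample is a list of floats, the norms are computed for that single sample (the library averages over a batch), alpha defaults to 1 (the VE case), and the noise is multiplied by scale_eps; these details are assumptions about how ccsd applies them, and langevin_correct is a hypothetical name.

```python
import math
import random
from typing import Callable, List, Tuple

# Hypothetical score function on one flattened sample: (x, t) -> grad log p_t(x).
ScoreFn = Callable[[List[float], float], List[float]]


def langevin_correct(score_fn: ScoreFn, x: List[float], t: float, snr: float,
                     scale_eps: float, n_steps: int, rng: random.Random,
                     alpha: float = 1.0) -> Tuple[List[float], List[float]]:
    """Run n_steps Langevin updates; returns (x, x_mean) like the documented update_fn."""
    x_mean = list(x)
    for _ in range(n_steps):
        grad = score_fn(x, t)
        noise = [rng.gauss(0.0, 1.0) for _ in x]
        grad_norm = math.sqrt(sum(g * g for g in grad))  # assumed non-zero
        noise_norm = math.sqrt(sum(n * n for n in noise))
        # With this step size, ||step * grad|| / ||sqrt(2 * step) * noise||
        # (noise taken before scale_eps) equals sqrt(alpha) * snr.
        step_size = (snr * noise_norm / grad_norm) ** 2 * 2.0 * alpha
        x_mean = [xi + step_size * gi for xi, gi in zip(x, grad)]
        x = [mi + math.sqrt(2.0 * step_size) * ni * scale_eps
             for mi, ni in zip(x_mean, noise)]
    return x, x_mean
```

Setting scale_eps to 0 turns each step into pure gradient ascent on the log-density, which is a quick way to check the drift direction.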

ccsd.src.solver.get_predictor(predictor: str) → Predictor

Get the predictor function.

Parameters:

predictor (str) – the predictor to use. Select from [Reverse, Euler].

Raises:

NotImplementedError – raise an error if the predictor is not recognized.

Returns:

the predictor function.

Return type:

Predictor

ccsd.src.solver.get_corrector(corrector: str) → Corrector

Get the corrector function.

Parameters:

corrector (str) – the corrector to use. Select from [Langevin, None].

Raises:

NotImplementedError – raise an error if the corrector is not recognized.

Returns:

the corrector function.

Return type:

Corrector
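Both getters map a name to a class and reject anything else with NotImplementedError. A minimal registry sketch; every class name here is a hypothetical placeholder for the ccsd predictors and correctors, and the error message is illustrative.

```python
from typing import Dict, Type, TypeVar

T = TypeVar("T")


class PredictorSketch: ...
class EulerSketch(PredictorSketch): ...
class ReverseSketch(PredictorSketch): ...
class CorrectorSketch: ...
class LangevinSketch(CorrectorSketch): ...
class NoneSketch(CorrectorSketch): ...


PREDICTORS: Dict[str, Type[PredictorSketch]] = {"Euler": EulerSketch, "Reverse": ReverseSketch}
CORRECTORS: Dict[str, Type[CorrectorSketch]] = {"Langevin": LangevinSketch, "None": NoneSketch}


def lookup(table: Dict[str, Type[T]], name: str, kind: str) -> Type[T]:
    """Return the class registered under name, or raise NotImplementedError."""
    if name not in table:
        raise NotImplementedError(f"{kind} {name!r} is not supported; choose from {sorted(table)}.")
    return table[name]
```

Note that the string "None" is a valid corrector name, distinct from Python's None.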

ccsd.src.solver.get_pc_sampler(sde_x: SDE, sde_adj: SDE, shape_x: Sequence[int], shape_adj: Sequence[int], predictor: str = 'Euler', corrector: str = 'None', snr: float = 0.1, scale_eps: float = 1.0, n_steps: int = 1, probability_flow: bool = False, continuous: bool = False, denoise: bool = True, eps: float = 0.001, device: str = 'cuda', is_cc: bool = False, sde_rank2: SDE | None = None, shape_rank2: Sequence[int] | None = None, d_min: int | None = None, d_max: int | None = None) → Callable[[Module, Module, Tensor], Tuple[Tensor, Tensor, float, List[List[Tensor]]]] | Callable[[Module, Module, Module, Tensor], Tuple[Tensor, Tensor, Tensor, float, List[List[Tensor]]]]

Return a predictor-corrector (PC) sampler.

Parameters:
  • sde_x (SDE) – SDE for the node features

  • sde_adj (SDE) – SDE for the adjacency matrix

  • shape_x (Sequence[int]) – shape of the node features

  • shape_adj (Sequence[int]) – shape of the adjacency matrix

  • predictor (str, optional) – name of the predictor. Select from [Euler, Reverse]. Defaults to “Euler”.

  • corrector (str, optional) – name of the corrector. Select from [Langevin, None]. Defaults to “None”.

  • snr (float, optional) – signal-to-noise ratio. Defaults to 0.1.

  • scale_eps (float, optional) – scale of the noise. Defaults to 1.0.

  • n_steps (int, optional) – number of corrector steps per predictor step. Defaults to 1.

  • probability_flow (bool, optional) – if True, use probability flow sampling. Defaults to False.

  • continuous (bool, optional) – if True, treat the SDEs as continuous-time when computing the score function. Defaults to False.

  • denoise (bool, optional) – if True, return the noise-free mean of the last reverse-SDE step instead of the final noisy sample. Defaults to True.

  • eps (float, optional) – final time of the reverse-time SDE (sampling runs from T down to eps). Defaults to 1e-3.

  • device (str, optional) – device to use. Defaults to “cuda”.

  • is_cc (bool, optional) – if True, get PC sampler for combinatorial complexes. Defaults to False.

  • sde_rank2 (Optional[SDE], optional) – SDE for the higher-order features (rank-2 incidence matrix). Defaults to None.

  • shape_rank2 (Optional[Sequence[int]], optional) – shape of the higher-order features (rank-2 incidence matrix). Defaults to None.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

Returns:

PC sampler

Return type:

Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]]]
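The loop below sketches what a PC sampler does on a toy problem where the exact score is known: 1-D Gaussian data with variance 0.25, diffused by a VP SDE with constant β = 5 (both assumptions, as are the hypothetical names `pc_sample` and `exact_score`). In this sketch each step runs `n_steps` Langevin corrector steps, then one Euler-Maruyama predictor step of the reverse SDE, integrating from T down to eps; with `denoise=True` it returns the mean of the last predictor step. It is an illustration, not CCSD's sampler.

```python
import math
import random
from typing import List

BETA = 5.0       # assumed constant noise rate: dx = -0.5*BETA*x dt + sqrt(BETA) dW
DATA_VAR = 0.25  # assumed variance of the toy 1-D Gaussian data


def marginal_var(t: float) -> float:
    """Var[x_t] when x_0 ~ N(0, DATA_VAR) is diffused by the toy VP SDE."""
    decay = math.exp(-BETA * t)
    return DATA_VAR * decay + (1.0 - decay)


def exact_score(xs: List[float], t: float) -> List[float]:
    """Exact score of N(0, marginal_var(t)); stands in for a trained model."""
    v = marginal_var(t)
    return [-x / v for x in xs]


def pc_sample(n: int, n_diff_steps: int = 200, T: float = 1.0, eps: float = 1e-3,
              snr: float = 0.1, scale_eps: float = 1.0, n_steps: int = 1,
              denoise: bool = True) -> List[float]:
    x = [random.gauss(0.0, 1.0) for _ in range(n)]  # sample from the N(0, 1) prior
    dt = -(T - eps) / n_diff_steps                   # negative: time runs from T to eps
    mean = x
    for i in range(n_diff_steps):
        t = T + i * dt
        # Corrector: n_steps Langevin steps (alpha taken as 1). Each scalar is one
        # sample, so its norm is |value|; norms are averaged over the batch.
        for _ in range(n_steps):
            grad = exact_score(x, t)
            noise = [random.gauss(0.0, 1.0) for _ in x]
            grad_norm = sum(abs(g) for g in grad) / n
            noise_norm = sum(abs(z) for z in noise) / n
            step = 2.0 * (snr * noise_norm / grad_norm) ** 2
            x = [xi + step * g + math.sqrt(2.0 * step) * scale_eps * z
                 for xi, g, z in zip(x, grad, noise)]
        # Predictor: Euler-Maruyama step of the reverse SDE, whose drift is
        # f(x) - g^2 * score = -0.5*BETA*x - BETA*score.
        score = exact_score(x, t)
        mean = [xi + (-0.5 * BETA * xi - BETA * s) * dt for xi, s in zip(x, score)]
        x = [m + math.sqrt(BETA * -dt) * random.gauss(0.0, 1.0) for m in mean]
    # denoise=True returns the last predictor mean, i.e. without its final noise.
    return mean if denoise else x
```

For example, `pc_sample(1000)` starts from prior samples of variance 1 and returns samples whose variance is close to the data variance of 0.25.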

ccsd.src.solver.S4_solver(sde_x: SDE, sde_adj: SDE, shape_x: Sequence[int], shape_adj: Sequence[int], predictor: str = 'None', corrector: str = 'None', snr: float = 0.1, scale_eps: float = 1.0, n_steps: int = 1, probability_flow: bool = False, continuous: bool = False, denoise: bool = True, eps: float = 0.001, device: str = 'cuda', is_cc: bool = False, sde_rank2: SDE | None = None, shape_rank2: Sequence[int] | None = None, d_min: int | None = None, d_max: int | None = None) Callable[[Module, Module, Tensor], Tuple[Tensor, Tensor, float, List[List[Tensor]]]] | Callable[[Module, Module, Module, Tensor], Tuple[Tensor, Tensor, Tensor, float, List[List[Tensor]]]]

Return an S4 (Symmetric Splitting for System of SDEs) sampler.

Parameters:
  • sde_x (SDE) – SDE for the node features

  • sde_adj (SDE) – SDE for the adjacency matrix

  • shape_x (Sequence[int]) – shape of the node features

  • shape_adj (Sequence[int]) – shape of the adjacency matrix

  • predictor (str, optional) – name of the predictor. Unused by this solver. Select from [Euler, Reverse]. Defaults to “None”.

  • corrector (str, optional) – name of the corrector. Unused by this solver. Select from [Langevin, None]. Defaults to “None”.

  • snr (float, optional) – signal-to-noise ratio. Defaults to 0.1.

  • scale_eps (float, optional) – scale of the noise. Defaults to 1.0.

  • n_steps (int, optional) – number of corrector steps per predictor step. Unused by this solver. Defaults to 1.

  • probability_flow (bool, optional) – if True, use probability flow sampling. Unused by this solver. Defaults to False.

  • continuous (bool, optional) – if True, treat the SDEs as continuous-time when computing the score function. Defaults to False.

  • denoise (bool, optional) – if True, return the noise-free mean of the last reverse-SDE step instead of the final noisy sample. Defaults to True.

  • eps (float, optional) – final time of the reverse-time SDE (sampling runs from T down to eps). Defaults to 1e-3.

  • device (str, optional) – device to use. Defaults to “cuda”.

  • is_cc (bool, optional) – if True, get S4 sampler for combinatorial complexes. Defaults to False.

  • sde_rank2 (Optional[SDE], optional) – SDE for the higher-order features (rank-2 incidence matrix). Defaults to None.

  • shape_rank2 (Optional[Sequence[int]], optional) – shape of the higher-order features (rank-2 incidence matrix). Defaults to None.

  • d_min (Optional[int], optional) – minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.

  • d_max (Optional[int], optional) – maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

Returns:

S4 sampler

Return type:

Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]]]
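A rough sketch of the symmetric (Strang) splitting idea behind an S4-style solver, on a toy problem: 1-D Gaussian data with variance 0.25 under a VP SDE with constant β = 5 (both assumptions; `splitting_sample` and `linear_part` are hypothetical names). Each step applies a half step of the linear-plus-noise part of the reverse SDE, solved in closed form, then a full step of the score drift with a single score evaluation, then another half step. This is a generic splitting scheme, not CCSD's code; in particular it omits the Langevin correction that the snr and scale_eps parameters suggest.

```python
import math
import random
from typing import List, Tuple

BETA = 5.0       # assumed constant noise rate of the toy VP SDE
DATA_VAR = 0.25  # assumed variance of the toy 1-D Gaussian data


def marginal_var(t: float) -> float:
    decay = math.exp(-BETA * t)
    return DATA_VAR * decay + (1.0 - decay)


def linear_part(xs: List[float], h: float) -> Tuple[List[float], List[float]]:
    """Solve the linear + noise part of the reverse SDE exactly over duration h.

    In reversed time s this part reads dx = 0.5*BETA*x ds + sqrt(BETA) dW,
    a linear SDE with a closed-form Gaussian transition. Returns (sample, mean).
    """
    growth = math.exp(0.5 * BETA * h)
    std = math.sqrt(math.expm1(BETA * h))  # sqrt(exp(BETA*h) - 1)
    mean = [growth * x for x in xs]
    return [m + std * random.gauss(0.0, 1.0) for m in mean], mean


def splitting_sample(n: int, n_diff_steps: int = 200, T: float = 1.0,
                     eps: float = 1e-3, denoise: bool = True) -> List[float]:
    x = [random.gauss(0.0, 1.0) for _ in range(n)]  # N(0, 1) prior
    h = (T - eps) / n_diff_steps
    mean = x
    for i in range(n_diff_steps):
        t = T - i * h
        # 1) Half step of the linear/noise part: t -> t - h/2.
        x, _ = linear_part(x, 0.5 * h)
        # 2) Full step of the score drift BETA * score, with one score evaluation
        #    at the midpoint (the exact Gaussian score stands in for the model).
        v_mid = marginal_var(t - 0.5 * h)
        x = [xi + BETA * h * (-xi / v_mid) for xi in x]
        # 3) Second half step of the linear/noise part: t - h/2 -> t - h.
        x, mean = linear_part(x, 0.5 * h)
    return mean if denoise else x
```

Splitting lets the linear part be integrated exactly, so only the score drift is discretized.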

losses.py: Loss functions for training the score-based models of the SDEs.

Adapted from Jo, J. & al (2022) and left almost untouched, except for get_sde_loss_fn_cc.

ccsd.src.losses.get_score_fn(sde: SDE, model: Module, train: bool = True, continuous: bool = True) Callable[[Tensor, Tensor, Tensor | None, Tensor], Tensor]

Return the score function for the SDE and the model.

Parameters:
  • sde (SDE) – Stochastic Differential Equation (SDE)

  • model (torch.nn.Module) – neural network model that predicts the score

  • train (bool, optional) – whether the model is being trained. Defaults to True.

  • continuous (bool, optional) – whether the SDE is continuous-time (the discrete case is not implemented). Defaults to True.

Raises:

NotImplementedError – raised if the SDE type is not supported

Returns:

score function

Return type:

Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], float], torch.Tensor]
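A minimal sketch of how such a score function can be built, assuming the convention of Song et al.'s reference code for score-based SDEs: for a VP SDE the network output is read as the added noise, so the score is its negation divided by the marginal standard deviation, while for a VE SDE the output is used as the score directly. `make_score_fn`, the `sde_kind` strings, and the toy model (which ignores t) are hypothetical.

```python
from typing import Callable, List

Vector = List[float]
Model = Callable[[Vector], Vector]
ScoreFn = Callable[[Vector, float], Vector]


def make_score_fn(sde_kind: str, model: Model,
                  marginal_std: Callable[[float], float],
                  continuous: bool = True) -> ScoreFn:
    """Wrap a model into a score function s(x, t) for a VP or VE SDE."""
    if not continuous:
        raise NotImplementedError("Discrete-time SDEs are not implemented.")
    if sde_kind == "VP":
        def score_fn(x: Vector, t: float) -> Vector:
            # VP: output read as the added noise, so score = -output / std(t).
            std = marginal_std(t)
            return [-out / std for out in model(x)]
    elif sde_kind == "VE":
        def score_fn(x: Vector, t: float) -> Vector:
            # VE: the output is used as the score directly.
            return model(x)
    else:
        raise NotImplementedError(f"SDE kind {sde_kind!r} not supported.")
    return score_fn
```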

ccsd.src.losses.get_score_fn_cc(sde: SDE, model: Module, train: bool = True, continuous: bool = True) Callable[[Tensor, Tensor, Tensor, Tensor | None, Tensor], Tensor]

Return the score function for the SDE and the model, for combinatorial complexes.

Parameters:
  • sde (SDE) – Stochastic Differential Equation (SDE)

  • model (torch.nn.Module) – neural network model that predicts the score

  • train (bool, optional) – whether the model is being trained. Defaults to True.

  • continuous (bool, optional) – whether the SDE is continuous-time (the discrete case is not implemented). Defaults to True.

Raises:

NotImplementedError – raised if the SDE type is not supported

Returns:

score function

Return type:

Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float], torch.Tensor]

ccsd.src.losses.get_sde_loss_fn(sde_x: SDE, sde_adj: SDE, train: bool = True, reduce_mean: bool = False, continuous: bool = True, likelihood_weighting: bool = False, eps: float = 1e-05) Callable[[Module, Module, Tensor, Tensor], Tuple[Tensor, Tensor]]

Return the loss function for the node-feature and adjacency SDEs, configured with the given parameters.

Parameters:
  • sde_x (SDE) – SDE for node features

  • sde_adj (SDE) – SDE for adjacency matrix

  • train (bool, optional) – whether the model is being trained. Defaults to True.

  • reduce_mean (bool, optional) – if True, reduce each sample's loss by averaging over its entries (flattened along the last axis) rather than summing them. Defaults to False.

  • continuous (bool, optional) – whether the SDE is continuous-time. Defaults to True.

  • likelihood_weighting (bool, optional) – if True, weight the loss with standard deviations. Defaults to False.

  • eps (float, optional) – smallest diffusion time sampled during training (t is drawn uniformly from [eps, T]). Defaults to 1e-5.

Returns:

loss function

Return type:

Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]
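The loss follows denoising score matching. Below is a minimal sketch for a single SDE on plain lists, assuming a VP SDE with constant β (`make_dsm_loss_fn` and `vp_marginal_prob` are hypothetical helpers). It samples t uniformly in [eps, T], perturbs the clean sample with the Gaussian kernel, and penalizes the residual std · score + z. get_sde_loss_fn applies the same idea to node features and adjacency and returns one loss per SDE; likelihood weighting is not modeled here.

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]
ScoreFn = Callable[[Vector, float], Vector]
MarginalProb = Callable[[Vector, float], Tuple[Vector, float]]

BETA = 5.0  # assumed constant noise rate of a toy VP SDE


def vp_marginal_prob(x0: Vector, t: float) -> Tuple[Vector, float]:
    """Mean and std of p_t(x_t | x_0) for dx = -0.5*BETA*x dt + sqrt(BETA) dW."""
    scale = math.exp(-0.5 * BETA * t)
    std = math.sqrt(-math.expm1(-BETA * t))  # sqrt(1 - exp(-BETA*t))
    return [scale * v for v in x0], std


def make_dsm_loss_fn(marginal_prob: MarginalProb, T: float = 1.0,
                     reduce_mean: bool = False, eps: float = 1e-5):
    """Denoising score-matching loss for one SDE (hypothetical helper)."""
    def loss_fn(score_fn: ScoreFn, batch: List[Vector]) -> float:
        per_sample = []
        for x0 in batch:
            t = random.uniform(eps, T)  # t ~ U[eps, T]
            mean, std = marginal_prob(x0, t)
            z = [random.gauss(0.0, 1.0) for _ in x0]
            xt = [m + std * zi for m, zi in zip(mean, z)]
            score = score_fn(xt, t)
            # The kernel's score is -z / std, so a perfect model gives std*score + z = 0.
            sq = [(s * std + zi) ** 2 for s, zi in zip(score, z)]
            # Mean over entries if reduce_mean, else half the sum (choice of this sketch).
            per_sample.append(sum(sq) / len(sq) if reduce_mean else 0.5 * sum(sq))
        return sum(per_sample) / len(per_sample)  # average over the batch
    return loss_fn
```

As a sanity check, if every training sample sits at the origin, the exact score of the perturbed distribution makes the residual vanish, while a model that always outputs zero gives a loss of about 1 with `reduce_mean=True`.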

ccsd.src.losses.get_sde_loss_fn_cc(sde_x: SDE, sde_adj: SDE, sde_rank2: SDE, d_min: int, d_max: int, train: bool = True, reduce_mean: bool = False, continuous: bool = True, likelihood_weighting: bool = False, eps: float = 1e-05) Callable[[Module, Module, Module, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]

Return the loss function for the node-feature, adjacency, and rank-2 incidence SDEs, configured with the given parameters.

Parameters:
  • sde_x (SDE) – SDE for node features

  • sde_adj (SDE) – SDE for adjacency matrix

  • sde_rank2 (SDE) – SDE for rank-2 incidence tensor

  • d_min (int) – minimum size of the rank-2 cells

  • d_max (int) – maximum size of the rank-2 cells

  • train (bool, optional) – whether the model is being trained. Defaults to True.

  • reduce_mean (bool, optional) – if True, reduce each sample's loss by averaging over its entries (flattened along the last axis) rather than summing them. Defaults to False.

  • continuous (bool, optional) – whether the SDE is continuous-time. Defaults to True.

  • likelihood_weighting (bool, optional) – if True, weight the loss with standard deviations. Defaults to False.

  • eps (float, optional) – smallest diffusion time sampled during training (t is drawn uniformly from [eps, T]). Defaults to 1e-5.

Returns:

loss function

Return type:

Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]