# Source code for ccsd.src.solver

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""solver.py: Contains the SDE solvers, and the predictor and corrector algorithms.
The correctors leverage score-based MCMC methods.

Adapted from Jo, J. et al. (2022) for Combinatorial Complexes and higher-order domain compatibility.
"""

import abc
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from tqdm import trange

from ccsd.src.losses import get_score_fn, get_score_fn_cc
from ccsd.src.sde import SDE, VPSDE, subVPSDE
from ccsd.src.utils.cc_utils import gen_noise_rank2, mask_rank2
from ccsd.src.utils.graph_utils import gen_noise, mask_adjs, mask_x
from ccsd.src.utils.models_utils import get_ones


class Predictor(abc.ABC):
    """Abstract class for a predictor algorithm."""

    def __init__(
        self,
        sde: SDE,
        score_fn: Union[
            Callable[
                [torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor],
                torch.Tensor,
            ],
            Callable[
                [
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    Optional[torch.Tensor],
                    torch.Tensor,
                ],
                torch.Tensor,
            ],
        ],
        probability_flow: bool = False,
        is_cc: bool = False,
        d_min: Optional[int] = None,
        d_max: Optional[int] = None,
    ) -> None:
        """Initialize the Predictor.

        Args:
            sde (SDE): the SDE to solve
            score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]): the score function
            probability_flow (bool, optional): if True, use probability flow sampling. Defaults to False.
            is_cc (bool, optional): if True, get predictor for combinatorial complexes. Defaults to False.
            d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
            d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        """
        super().__init__()
        # Initialize the Predictor
        self.sde = sde
        # Keep the flag as an attribute: the concrete predictors' __repr__
        # read self.probability_flow and would otherwise raise AttributeError.
        self.probability_flow = probability_flow
        # Compute the reverse SDE/ODE
        self.rsde = sde.reverse(score_fn, probability_flow, is_cc)
        self.score_fn = score_fn
        self.is_cc = is_cc
        self.d_min = d_min
        self.d_max = d_max

    @abc.abstractmethod
    def update_fn(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the predictor class.

        Args:
            x (torch.Tensor): tensor
            adj (torch.Tensor): adjacency matrix. Optional.
            rank2 (torch.Tensor): rank-2 tensor. Optional.
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        pass
class Corrector(abc.ABC):
    """Base interface shared by every corrector algorithm."""

    def __init__(
        self,
        sde: SDE,
        score_fn: Union[
            Callable[
                [torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor],
                torch.Tensor,
            ],
            Callable[
                [
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    Optional[torch.Tensor],
                    torch.Tensor,
                ],
                torch.Tensor,
            ],
        ],
        snr: float,
        scale_eps: float,
        n_steps: int,
        is_cc: bool = False,
        d_min: Optional[int] = None,
        d_max: Optional[int] = None,
    ) -> None:
        """Store the configuration of the corrector.

        Args:
            sde (SDE): the SDE to solve
            score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]): the score function
            snr (float): signal-to-noise ratio
            scale_eps (float): scale of the noise
            n_steps (int): number of steps
            is_cc (bool, optional): if True, get corrector for combinatorial complexes. Defaults to False.
            d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
            d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        """
        super().__init__()
        # SDE and its score function
        self.sde, self.score_fn = sde, score_fn
        # Hyperparameters of the correction steps
        self.snr, self.scale_eps, self.n_steps = snr, scale_eps, n_steps
        # Combinatorial complex settings
        self.is_cc, self.d_min, self.d_max = is_cc, d_min, d_max

    @abc.abstractmethod
    def update_fn(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one corrector update; must be provided by subclasses.

        Args:
            x (torch.Tensor): tensor
            adj (torch.Tensor): adjacency matrix. Optional.
            rank2 (torch.Tensor): rank-2 tensor. Optional.
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
class EulerMaruyamaPredictor(Predictor):
    """Euler-Maruyama predictor."""

    def __init__(
        self,
        obj: str,
        sde: SDE,
        score_fn: Union[
            Callable[
                [torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor],
                torch.Tensor,
            ],
            Callable[
                [
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    Optional[torch.Tensor],
                    torch.Tensor,
                ],
                torch.Tensor,
            ],
        ],
        probability_flow: bool = False,
        is_cc: bool = False,
        d_min: Optional[int] = None,
        d_max: Optional[int] = None,
    ) -> None:
        """Initialize the Euler-Maruyama predictor.

        Args:
            obj (str): object to update, either "x", "adj", or "rank2"
            sde (SDE): the SDE to solve
            score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]): the score function
            probability_flow (bool, optional): if True, use probability flow sampling. Defaults to False.
            is_cc (bool, optional): if True, get Euler-Maruyama predictor for combinatorial complexes. Defaults to False.
            d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
            d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        """
        super().__init__(sde, score_fn, probability_flow, is_cc, d_min, d_max)
        self.obj = obj
        # __repr__ reads this attribute; store it here so that printing the
        # predictor never raises AttributeError.
        self.probability_flow = probability_flow

    def __repr__(self) -> str:
        """Representation of the Euler-Maruyama predictor."""
        return f"{self.__class__.__name__}(obj={self.obj}, sde={self.sde.__class__.__name__}, probability_flow={self.probability_flow}, is_cc={self.is_cc}, d_min={self.d_min}, d_max={self.d_max})"

    @staticmethod
    def _euler_maruyama_step(
        state: torch.Tensor,
        drift: torch.Tensor,
        diffusion: torch.Tensor,
        noise: torch.Tensor,
        dt: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one Euler-Maruyama step to ``state``.

        Args:
            state (torch.Tensor): current value of the object being updated
            drift (torch.Tensor): drift returned by the reverse SDE
            diffusion (torch.Tensor): diffusion coefficient returned by the reverse SDE
                (indexed as ``[:, None, None]`` to broadcast over ``state``)
            noise (torch.Tensor): noise sample with the same shape as ``state``
            dt (float): (negative) time increment

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        mean = state + drift * dt
        updated = mean + diffusion[:, None, None] * np.sqrt(-dt) * noise
        return updated, mean

    def update_fn(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Euler-Maruyama predictor.

        Dispatches to the combinatorial complex or graph update depending on
        ``self.is_cc``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        if self.is_cc:
            return self.update_fn_cc(*args, **kwargs)
        else:
            return self.update_fn_graph(*args, **kwargs)

    def update_fn_graph(
        self, x: torch.Tensor, adj: torch.Tensor, flags: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Euler-Maruyama predictor for graphs.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        # Negative step: the reverse SDE is integrated backward in time
        dt = -1.0 / self.rsde.N

        # Reverse SDE for the node features
        if self.obj == "x":
            z = gen_noise(x, flags, sym=False)
            drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=False)
            return self._euler_maruyama_step(x, drift, diffusion, z, dt)

        # Reverse SDE for the adjacency matrix
        elif self.obj == "adj":
            z = gen_noise(adj, flags, sym=True)
            drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=True)
            return self._euler_maruyama_step(adj, drift, diffusion, z, dt)

        # Raise error if obj is not recognized
        else:
            raise NotImplementedError(
                f"Object {self.obj} not yet supported. Select from [x, adj]."
            )

    def update_fn_cc(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Euler-Maruyama predictor for combinatorial complexes.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            rank2 (torch.Tensor): rank-2 incidence matrix
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        # Negative step: the reverse SDE is integrated backward in time
        dt = -1.0 / self.rsde.N

        # Reverse SDE for the node features
        if self.obj == "x":
            z = gen_noise(x, flags, sym=False)
            drift, diffusion = self.rsde.sde(
                x, adj, rank2, flags, t, is_adj=False, is_rank2=False
            )
            return self._euler_maruyama_step(x, drift, diffusion, z, dt)

        # Reverse SDE for the adjacency matrix
        elif self.obj == "adj":
            z = gen_noise(adj, flags, sym=True)
            drift, diffusion = self.rsde.sde(
                x, adj, rank2, flags, t, is_adj=True, is_rank2=False
            )
            return self._euler_maruyama_step(adj, drift, diffusion, z, dt)

        # Reverse SDE for the rank2 incidence matrix
        elif self.obj == "rank2":
            z = gen_noise_rank2(rank2, adj.shape[1], self.d_min, self.d_max, flags)
            drift, diffusion = self.rsde.sde(
                x, adj, rank2, flags, t, is_adj=False, is_rank2=True
            )
            return self._euler_maruyama_step(rank2, drift, diffusion, z, dt)

        # Raise error if obj is not recognized
        else:
            raise NotImplementedError(
                f"Object {self.obj} not yet supported. Select from [x, adj, rank2]."
            )
class ReverseDiffusionPredictor(Predictor):
    """Reverse diffusion predictor."""

    def __init__(
        self,
        obj: str,
        sde: SDE,
        score_fn: Union[
            Callable[
                [torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor],
                torch.Tensor,
            ],
            Callable[
                [
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    Optional[torch.Tensor],
                    torch.Tensor,
                ],
                torch.Tensor,
            ],
        ],
        probability_flow: bool = False,
        is_cc: bool = False,
        d_min: Optional[int] = None,
        d_max: Optional[int] = None,
    ) -> None:
        """Initialize the Reverse Diffusion predictor.

        Args:
            obj (str): object to update, either "x", "adj" or "rank2"
            sde (SDE): the SDE to solve
            score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]): the score function
            probability_flow (bool, optional): if True, use probability flow sampling. Defaults to False.
            is_cc (bool, optional): if True, get Reverse Diffusion predictor for combinatorial complexes. Defaults to False.
            d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
            d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        """
        super().__init__(sde, score_fn, probability_flow, is_cc, d_min, d_max)
        self.obj = obj
        # __repr__ reads this attribute; store it here so that printing the
        # predictor never raises AttributeError.
        self.probability_flow = probability_flow

    def __repr__(self) -> str:
        """Representation of the Reverse Diffusion predictor."""
        return f"{self.__class__.__name__}(obj={self.obj}, sde={self.sde.__class__.__name__}, probability_flow={self.probability_flow}, is_cc={self.is_cc}, d_min={self.d_min}, d_max={self.d_max})"

    @staticmethod
    def _reverse_diffusion_step(
        state: torch.Tensor, f: torch.Tensor, G: torch.Tensor, noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one reverse diffusion step to ``state``.

        Args:
            state (torch.Tensor): current value of the object being updated
            f (torch.Tensor): discretized drift returned by the reverse SDE
            G (torch.Tensor): discretized diffusion returned by the reverse SDE
                (indexed as ``[:, None, None]`` to broadcast over ``state``)
            noise (torch.Tensor): noise sample with the same shape as ``state``

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        mean = state - f
        updated = mean + G[:, None, None] * noise
        return updated, mean

    def update_fn(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Reverse Diffusion predictor.

        Dispatches to the combinatorial complex or graph update depending on
        ``self.is_cc``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        if self.is_cc:
            return self.update_fn_cc(*args, **kwargs)
        else:
            return self.update_fn_graph(*args, **kwargs)

    def update_fn_graph(
        self, x: torch.Tensor, adj: torch.Tensor, flags: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Reverse Diffusion predictor for graphs.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        # Reverse SDE for the node features
        if self.obj == "x":
            f, G = self.rsde.discretize(x, adj, flags, t, is_adj=False)
            z = gen_noise(x, flags, sym=False)
            return self._reverse_diffusion_step(x, f, G, z)

        # Reverse SDE for the adjacency matrix
        elif self.obj == "adj":
            f, G = self.rsde.discretize(x, adj, flags, t, is_adj=True)
            z = gen_noise(adj, flags)
            return self._reverse_diffusion_step(adj, f, G, z)

        # Raise error if obj is not recognized
        else:
            raise NotImplementedError(
                f"Object {self.obj} not yet supported. Select from [x, adj]."
            )

    def update_fn_cc(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Reverse Diffusion predictor for combinatorial complexes.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            rank2 (torch.Tensor): rank-2 incidence matrix
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        # Reverse SDE for the node features
        if self.obj == "x":
            f, G = self.rsde.discretize(
                x, adj, rank2, flags, t, is_adj=False, is_rank2=False
            )
            z = gen_noise(x, flags, sym=False)
            return self._reverse_diffusion_step(x, f, G, z)

        # Reverse SDE for the adjacency matrix
        elif self.obj == "adj":
            f, G = self.rsde.discretize(
                x, adj, rank2, flags, t, is_adj=True, is_rank2=False
            )
            z = gen_noise(adj, flags)
            return self._reverse_diffusion_step(adj, f, G, z)

        # Reverse SDE for the rank2 incidence matrix
        elif self.obj == "rank2":
            f, G = self.rsde.discretize(
                x, adj, rank2, flags, t, is_adj=False, is_rank2=True
            )
            z = gen_noise_rank2(rank2, adj.shape[1], self.d_min, self.d_max, flags)
            return self._reverse_diffusion_step(rank2, f, G, z)

        # Raise error if obj is not recognized
        else:
            raise NotImplementedError(
                f"Object {self.obj} not yet supported. Select from [x, adj, rank2]."
            )
class NoneCorrector(Corrector):
    """A pass-through corrector: the selected object is returned untouched."""

    def __init__(
        self,
        obj: str,
        sde: SDE,
        score_fn: Union[
            Callable[
                [torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor],
                torch.Tensor,
            ],
            Callable[
                [
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    Optional[torch.Tensor],
                    torch.Tensor,
                ],
                torch.Tensor,
            ],
        ],
        snr: float,
        scale_eps: float,
        n_steps: int,
        is_cc: bool = False,
        d_min: Optional[int] = None,
        d_max: Optional[int] = None,
    ):
        """Build the pass-through corrector.

        Every argument except ``obj`` and ``is_cc`` is only stored by the base
        class; none of them influences the update.

        Args:
            obj (str): object to update, either "x", "adj" or "rank2"
            sde (SDE): the SDE to solve. UNUSED HERE
            score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]): the score function. UNUSED HERE
            snr (float): signal-to-noise ratio. UNUSED HERE
            scale_eps (float): scale of the noise. UNUSED HERE
            n_steps (int): number of steps to take. UNUSED HERE
            is_cc (bool, optional): if True, get NoneCorrector for combinatorial complexes. Defaults to False.
            d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
            d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        """
        super().__init__(sde, score_fn, snr, scale_eps, n_steps, is_cc, d_min, d_max)
        self.obj = obj

    def __repr__(self) -> str:
        """Return a string describing the corrector configuration."""
        return f"{self.__class__.__name__}(obj={self.obj}, sde={self.sde.__class__.__name__}, snr={self.snr}, scale_eps={self.scale_eps}, n_steps={self.n_steps}, is_cc={self.is_cc}, d_min={self.d_min}, d_max={self.d_max})"

    def update_fn(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route the call to the graph or combinatorial complex variant.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        handler = self.update_fn_cc if self.is_cc else self.update_fn_graph
        return handler(*args, **kwargs)

    def update_fn_graph(
        self, x: torch.Tensor, adj: torch.Tensor, flags: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the selected graph object unchanged, as both value and mean.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            flags (torch.Tensor): flags (unused)
            t (torch.Tensor): timestep (unused)

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        # Node features
        if self.obj == "x":
            return x, x
        # Adjacency matrix
        if self.obj == "adj":
            return adj, adj
        # Anything else is not supported
        raise NotImplementedError(
            f"Object {self.obj} not yet supported. Select from [x, adj]."
        )

    def update_fn_cc(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the selected combinatorial complex object unchanged, as both value and mean.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            rank2 (torch.Tensor): rank-2 incidence matrix
            flags (torch.Tensor): flags (unused)
            t (torch.Tensor): timestep (unused)

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        # Node features
        if self.obj == "x":
            return x, x
        # Adjacency matrix
        if self.obj == "adj":
            return adj, adj
        # Rank-2 incidence matrix
        if self.obj == "rank2":
            return rank2, rank2
        # Anything else is not supported
        raise NotImplementedError(
            f"Object {self.obj} not yet supported. Select from [x, adj, rank2]."
        )
class LangevinCorrector(Corrector):
    """Langevin corrector."""

    def __init__(
        self,
        obj: str,
        sde: SDE,
        score_fn: Union[
            Callable[
                [torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor],
                torch.Tensor,
            ],
            Callable[
                [
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    Optional[torch.Tensor],
                    torch.Tensor,
                ],
                torch.Tensor,
            ],
        ],
        snr: float,
        scale_eps: float,
        n_steps: int,
        is_cc: bool = False,
        d_min: Optional[int] = None,
        d_max: Optional[int] = None,
    ) -> None:
        """Initialize the Langevin corrector.

        Args:
            obj (str): object to update, either "x", "adj" or "rank2"
            sde (SDE): the SDE to solve
            score_fn (Union[Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor], Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor], torch.Tensor]]): the score function
            snr (float): signal-to-noise ratio
            scale_eps (float): scale of the noise
            n_steps (int): number of steps to take
            is_cc (bool, optional): if True, get Langevin corrector for combinatorial complexes. Defaults to False.
            d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
            d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        """
        super().__init__(sde, score_fn, snr, scale_eps, n_steps, is_cc, d_min, d_max)
        self.obj = obj

    def __repr__(self) -> str:
        """Representation of the Langevin corrector."""
        return f"{self.__class__.__name__}(obj={self.obj}, sde={self.sde.__class__.__name__}, snr={self.snr}, scale_eps={self.scale_eps}, n_steps={self.n_steps}, is_cc={self.is_cc}, d_min={self.d_min}, d_max={self.d_max})"

    def _get_alpha(self, t: torch.Tensor) -> torch.Tensor:
        """Return the step-size scaling factor ``alpha`` for timestep ``t``.

        For VPSDE/subVPSDE, ``t`` is mapped to a discrete index and used to
        look up ``sde.alphas``; for any other SDE a tensor of ones is used.

        Args:
            t (torch.Tensor): timestep

        Returns:
            torch.Tensor: alpha, with the same shape as ``t``
        """
        sde = self.sde
        if isinstance(sde, (VPSDE, subVPSDE)):
            timestep = (t * (sde.N - 1) / sde.T).long()
            return sde.alphas.to(t.device)[timestep]
        return get_ones(t.shape, t.device)

    def _langevin_step(
        self,
        state: torch.Tensor,
        grad: torch.Tensor,
        noise: torch.Tensor,
        alpha: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one Langevin step to ``state``.

        The step size is ``(snr * ||noise|| / ||grad||) ** 2 * 2 * alpha``,
        where the norms are taken per sample (first dimension) and averaged.

        Args:
            state (torch.Tensor): current value of the object being updated
            grad (torch.Tensor): score evaluated at the current state
            noise (torch.Tensor): noise sample with the same shape as ``state``
            alpha (torch.Tensor): step-size scaling factor
                (indexed as ``[:, None, None]`` to broadcast over ``state``)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 * alpha
        mean = state + step_size[:, None, None] * grad
        updated = (
            mean + torch.sqrt(step_size * 2)[:, None, None] * noise * self.scale_eps
        )
        return updated, mean

    def update_fn(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Langevin corrector.

        Dispatches to the combinatorial complex or graph update depending on
        ``self.is_cc``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        if self.is_cc:
            return self.update_fn_cc(*args, **kwargs)
        else:
            return self.update_fn_graph(*args, **kwargs)

    def update_fn_graph(
        self, x: torch.Tensor, adj: torch.Tensor, flags: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Langevin corrector for graphs.

        If ``n_steps`` is 0, the selected object is returned unchanged as both
        the updated tensor and the mean.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        alpha = self._get_alpha(t)

        # Reverse SDE for the node features
        if self.obj == "x":
            # Pre-bind the mean so that n_steps == 0 does not leave it undefined
            x_mean = x
            for _ in range(self.n_steps):
                grad = self.score_fn(x, adj, flags, t)
                noise = gen_noise(x, flags, sym=False)
                x, x_mean = self._langevin_step(x, grad, noise, alpha)
            return x, x_mean

        # Reverse SDE for the adjacency matrix
        elif self.obj == "adj":
            # Pre-bind the mean so that n_steps == 0 does not leave it undefined
            adj_mean = adj
            for _ in range(self.n_steps):
                grad = self.score_fn(x, adj, flags, t)
                noise = gen_noise(adj, flags)
                adj, adj_mean = self._langevin_step(adj, grad, noise, alpha)
            return adj, adj_mean

        # Raise error if obj is not recognized
        else:
            raise NotImplementedError(
                f"Object {self.obj} not yet supported. Select from [x, adj]."
            )

    def update_fn_cc(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update function for the Langevin corrector for combinatorial complexes.

        If ``n_steps`` is 0, the selected object is returned unchanged as both
        the updated tensor and the mean.

        Args:
            x (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix
            rank2 (torch.Tensor): rank-2 incidence matrix
            flags (torch.Tensor): flags
            t (torch.Tensor): timestep

        Raises:
            NotImplementedError: raise an error if the object to update is not recognized.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: updated tensor and mean
        """
        alpha = self._get_alpha(t)

        # Reverse SDE for the node features
        if self.obj == "x":
            # Pre-bind the mean so that n_steps == 0 does not leave it undefined
            x_mean = x
            for _ in range(self.n_steps):
                grad = self.score_fn(x, adj, rank2, flags, t)
                noise = gen_noise(x, flags, sym=False)
                x, x_mean = self._langevin_step(x, grad, noise, alpha)
            return x, x_mean

        # Reverse SDE for the adjacency matrix
        elif self.obj == "adj":
            # Pre-bind the mean so that n_steps == 0 does not leave it undefined
            adj_mean = adj
            for _ in range(self.n_steps):
                grad = self.score_fn(x, adj, rank2, flags, t)
                noise = gen_noise(adj, flags)
                adj, adj_mean = self._langevin_step(adj, grad, noise, alpha)
            return adj, adj_mean

        # Reverse SDE for the rank2 incidence matrix
        elif self.obj == "rank2":
            # Pre-bind the mean so that n_steps == 0 does not leave it undefined
            rank2_mean = rank2
            for _ in range(self.n_steps):
                grad = self.score_fn(x, adj, rank2, flags, t)
                noise = gen_noise_rank2(
                    rank2, adj.shape[1], self.d_min, self.d_max, flags
                )
                rank2, rank2_mean = self._langevin_step(rank2, grad, noise, alpha)
            return rank2, rank2_mean

        # Raise error if obj is not recognized
        else:
            raise NotImplementedError(
                f"Object {self.obj} not yet supported. Select from [x, adj, rank2]."
            )
def get_predictor(predictor: str) -> Callable[..., Predictor]:
    """Get the predictor class matching the given name.

    The class itself is returned, not an instance: the caller instantiates it
    with the object name, SDE, score function, etc.

    Args:
        predictor (str): the predictor to use. Select from [Reverse, Euler].

    Raises:
        NotImplementedError: raise an error if the predictor is not recognized.

    Returns:
        Callable[..., Predictor]: the predictor class.
    """
    if predictor == "Reverse":
        return ReverseDiffusionPredictor
    if predictor == "Euler":
        return EulerMaruyamaPredictor
    raise NotImplementedError(
        f"Predictor {predictor} not yet supported. Select from [Reverse, Euler]."
    )
def get_corrector(corrector: str) -> Callable[..., Corrector]:
    """Get the corrector class matching the given name.

    The class itself is returned, not an instance: the caller instantiates it
    with the object name, SDE, score function, etc.

    Args:
        corrector (str): the corrector to use. Select from [Langevin, None].
            Note that "None" is the string, not the ``None`` singleton.

    Raises:
        NotImplementedError: raise an error if the corrector is not recognized.

    Returns:
        Callable[..., Corrector]: the corrector class.
    """
    if corrector == "Langevin":
        return LangevinCorrector
    if corrector == "None":
        return NoneCorrector
    raise NotImplementedError(
        f"Corrector {corrector} not yet supported. Select from [Langevin, None]."
    )
def get_pc_sampler(
    sde_x: SDE,
    sde_adj: SDE,
    shape_x: Sequence[int],
    shape_adj: Sequence[int],
    predictor: str = "Euler",
    corrector: str = "None",
    snr: float = 0.1,
    scale_eps: float = 1.0,
    n_steps: int = 1,
    probability_flow: bool = False,
    continuous: bool = False,
    denoise: bool = True,
    eps: float = 1e-3,
    device: str = "cuda",
    is_cc: bool = False,
    sde_rank2: Optional[SDE] = None,
    shape_rank2: Optional[Sequence[int]] = None,
    d_min: Optional[int] = None,
    d_max: Optional[int] = None,
) -> Union[
    Callable[
        [torch.nn.Module, torch.nn.Module, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]],
    ],
    Callable[
        [torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor],
        Tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]
        ],
    ],
]:
    """Returns a PC (predictor-corrector) sampler.

    Depending on `is_cc`, the returned closure samples either graphs
    (node features and adjacency matrix) or combinatorial complexes
    (node features, adjacency matrix and rank-2 incidence matrix).
    When `is_cc` is True, `sde_rank2` and `shape_rank2` are used by the
    returned sampler without any None check, so they must be provided.

    Args:
        sde_x (SDE): SDE for the node features
        sde_adj (SDE): SDE for the adjacency matrix
        shape_x (Sequence[int]): shape of the node features
        shape_adj (Sequence[int]): shape of the adjacency matrix
        predictor (str, optional): predictor function. Select from [Euler, Reverse]. Defaults to "Euler".
        corrector (str, optional): corrector function. Select from [Langevin, None]. Defaults to "None".
        snr (float, optional): signal-to-noise ratio. Defaults to 0.1.
        scale_eps (float, optional): scale of the noise. Defaults to 1.0.
        n_steps (int, optional): number of steps to take. Defaults to 1.
        probability_flow (bool, optional): if True, use probability flow sampling. Defaults to False.
        continuous (bool, optional): if True, use continuous-time SDEs, for the score function. Defaults to False.
        denoise (bool, optional): if True, use denoising diffusion (returns the mean of the reverse SDE). Defaults to True.
        eps (float, optional): epsilon for the reverse-time SDE (last timestep of the time grid). Defaults to 1e-3.
        device (str, optional): device to use. Defaults to "cuda".
        is_cc (bool, optional): if True, get PC sampler for combinatorial complexes. Defaults to False.
        sde_rank2 (Optional[SDE], optional): SDE for the higher-order features. Defaults to None.
        shape_rank2 (Optional[Sequence[int]], optional): shape of the higher-order features. Defaults to None.
        d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

    Returns:
        Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]]]: PC sampler
    """

    if not (is_cc):

        def pc_sampler(
            model_x: torch.nn.Module,
            model_adj: torch.nn.Module,
            init_flags: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]:
            """PC sampler: sample from the model.

            Args:
                model_x (torch.nn.Module): model for the node features
                model_adj (torch.nn.Module): model for the adjacency matrix
                init_flags (torch.Tensor): initial flags

            Returns:
                Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]: node features, adjacency matrix, `diff_steps * (n_steps + 1)`, one complete diffusion trajectory (first element of the batch only)
            """

            # Get score functions
            score_fn_x = get_score_fn(
                sde_x, model_x, train=False, continuous=continuous
            )
            score_fn_adj = get_score_fn(
                sde_adj, model_adj, train=False, continuous=continuous
            )

            # Get predictor and corrector functions
            # (these are classes, instantiated once per object below)
            predictor_fn = get_predictor(predictor)
            corrector_fn = get_corrector(corrector)

            # Evaluate the predictor and corrector
            predictor_obj_x = predictor_fn("x", sde_x, score_fn_x, probability_flow)
            corrector_obj_x = corrector_fn(
                "x", sde_x, score_fn_x, snr, scale_eps, n_steps
            )
            predictor_obj_adj = predictor_fn(
                "adj", sde_adj, score_fn_adj, probability_flow
            )
            corrector_obj_adj = corrector_fn(
                "adj", sde_adj, score_fn_adj, snr, scale_eps, n_steps
            )

            # One complete diffusion trajectory
            diff_traj = []

            with torch.no_grad():
                # Initial sample
                x = sde_x.prior_sampling(shape_x).to(device)
                adj = sde_adj.prior_sampling_sym(shape_adj).to(device)
                flags = init_flags
                # Mask the initial sample
                x = mask_x(x, flags)
                adj = mask_adjs(adj, flags)
                # The time grid is driven by the adjacency SDE: from T down to eps
                diff_steps = sde_adj.N
                timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device)

                # Reverse diffusion process
                for i in trange(
                    0, (diff_steps), desc="[Sampling]", position=1, leave=False
                ):
                    t = timesteps[i]
                    # One copy of t per entry along the first dimension of shape_adj
                    # (presumably the batch size -- TODO confirm)
                    vec_t = get_ones((shape_adj[0],), device=t.device) * t

                    # Corrector step: adj is corrected with the x from BEFORE
                    # the x correction (snapshot kept in _x)
                    _x = x
                    x, x_mean = corrector_obj_x.update_fn(x, adj, flags, vec_t)
                    adj, adj_mean = corrector_obj_adj.update_fn(_x, adj, flags, vec_t)

                    # Predictor step: same snapshot logic as for the corrector
                    _x = x
                    x, x_mean = predictor_obj_x.update_fn(x, adj, flags, vec_t)
                    adj, adj_mean = predictor_obj_adj.update_fn(_x, adj, flags, vec_t)

                    # Add diffusion trajectory (only the first element of the batch)
                    if denoise:
                        diff_traj.append(
                            [x_mean[0].detach().clone(), adj_mean[0].detach().clone()]
                        )
                    else:
                        diff_traj.append(
                            [x[0].detach().clone(), adj[0].detach().clone()]
                        )
                print(" ")
                # With denoise, the noise-free means of the last step are returned
                return (
                    (x_mean if denoise else x),
                    (adj_mean if denoise else adj),
                    diff_steps * (n_steps + 1),
                    diff_traj,
                )

    else:

        def pc_sampler(
            model_x: torch.nn.Module,
            model_adj: torch.nn.Module,
            model_rank2: torch.nn.Module,
            init_flags: torch.Tensor,
        ) -> Tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]
        ]:
            """PC sampler: sample from the model.

            Args:
                model_x (torch.nn.Module): model for the node features
                model_adj (torch.nn.Module): model for the adjacency matrix
                model_rank2 (torch.nn.Module): model for the higher-order features (rank2 incidence matrix)
                init_flags (torch.Tensor): initial flags

            Returns:
                Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]: node features, adjacency matrix, rank2 incidence matrix, `diff_steps * (n_steps + 1)`, one complete diffusion trajectory (first element of the batch only)
            """

            # Get score functions
            score_fn_x = get_score_fn_cc(
                sde_x, model_x, train=False, continuous=continuous
            )
            score_fn_adj = get_score_fn_cc(
                sde_adj, model_adj, train=False, continuous=continuous
            )
            score_fn_rank2 = get_score_fn_cc(
                sde_rank2, model_rank2, train=False, continuous=continuous
            )

            # Get predictor and corrector functions
            # (these are classes, instantiated once per object below)
            predictor_fn = get_predictor(predictor)
            corrector_fn = get_corrector(corrector)

            # Evaluate the predictor and corrector
            predictor_obj_x = predictor_fn(
                "x",
                sde_x,
                score_fn_x,
                probability_flow,
                is_cc=True,
                d_min=d_min,
                d_max=d_max,
            )
            corrector_obj_x = corrector_fn(
                "x",
                sde_x,
                score_fn_x,
                snr,
                scale_eps,
                n_steps,
                is_cc=True,
                d_min=d_min,
                d_max=d_max,
            )
            predictor_obj_adj = predictor_fn(
                "adj",
                sde_adj,
                score_fn_adj,
                probability_flow,
                is_cc=True,
                d_min=d_min,
                d_max=d_max,
            )
            corrector_obj_adj = corrector_fn(
                "adj",
                sde_adj,
                score_fn_adj,
                snr,
                scale_eps,
                n_steps,
                is_cc=True,
                d_min=d_min,
                d_max=d_max,
            )
            predictor_obj_rank2 = predictor_fn(
                "rank2",
                sde_rank2,
                score_fn_rank2,
                probability_flow,
                is_cc=True,
                d_min=d_min,
                d_max=d_max,
            )
            corrector_obj_rank2 = corrector_fn(
                "rank2",
                sde_rank2,
                score_fn_rank2,
                snr,
                scale_eps,
                n_steps,
                is_cc=True,
                d_min=d_min,
                d_max=d_max,
            )

            # One complete diffusion trajectory
            diff_traj = []

            with torch.no_grad():
                # Initial sample
                x = sde_x.prior_sampling(shape_x).to(device)
                adj = sde_adj.prior_sampling_sym(shape_adj).to(device)
                rank2 = sde_rank2.prior_sampling(shape_rank2).to(device)
                flags = init_flags
                # Mask the initial sample
                x = mask_x(x, flags)
                adj = mask_adjs(adj, flags)
                rank2 = mask_rank2(rank2, adj.shape[1], d_min, d_max, flags)
                # The time grid is driven by the adjacency SDE: from T down to eps
                diff_steps = sde_adj.N
                timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device)

                # Reverse diffusion process
                for i in trange(
                    0, (diff_steps), desc="[Sampling]", position=1, leave=False
                ):
                    t = timesteps[i]
                    # One copy of t per entry along the first dimension of shape_adj
                    # (presumably the batch size -- TODO confirm)
                    vec_t = get_ones((shape_adj[0],), device=t.device) * t

                    # Corrector step: adj and rank2 are corrected with the x
                    # (and, for rank2, the adj) from BEFORE this step's
                    # corrections (snapshots kept in _x and _adj)
                    _x = x
                    _adj = adj
                    x, x_mean = corrector_obj_x.update_fn(x, adj, rank2, flags, vec_t)
                    adj, adj_mean = corrector_obj_adj.update_fn(
                        _x, adj, rank2, flags, vec_t
                    )
                    rank2, rank2_mean = corrector_obj_rank2.update_fn(
                        _x, _adj, rank2, flags, vec_t
                    )

                    # Predictor step: same snapshot logic as for the corrector
                    _x = x
                    _adj = adj
                    x, x_mean = predictor_obj_x.update_fn(x, adj, rank2, flags, vec_t)
                    adj, adj_mean = predictor_obj_adj.update_fn(
                        _x, adj, rank2, flags, vec_t
                    )
                    rank2, rank2_mean = predictor_obj_rank2.update_fn(
                        _x, _adj, rank2, flags, vec_t
                    )

                    # Add diffusion trajectory (only the first element of the batch)
                    if denoise:
                        diff_traj.append(
                            [
                                x_mean[0].detach().clone(),
                                adj_mean[0].detach().clone(),
                                rank2_mean[0].detach().clone(),
                            ]
                        )
                    else:
                        diff_traj.append(
                            [
                                x[0].detach().clone(),
                                adj[0].detach().clone(),
                                rank2[0].detach().clone(),
                            ]
                        )
                print(" ")
                # With denoise, the noise-free means of the last step are returned
                return (
                    (x_mean if denoise else x),
                    (adj_mean if denoise else adj),
                    (rank2_mean if denoise else rank2),
                    diff_steps * (n_steps + 1),
                    diff_traj,
                )

    return pc_sampler
def _s4_corrector_step(
    sde: SDE,
    sample: torch.Tensor,
    score: torch.Tensor,
    noise: torch.Tensor,
    timestep: torch.Tensor,
    vec_t: torch.Tensor,
    snr: float,
    scale_eps: float,
) -> torch.Tensor:
    """Apply one score-based correction step of the S4 solver to one component.

    The step size is ``(snr * noise_norm / grad_norm) ** 2 * 2 * alpha`` where
    ``noise_norm`` and ``grad_norm`` are the batch means of the per-sample norms
    of ``noise`` and ``score`` (each flattened to ``(batch, -1)``).

    Args:
        sde (SDE): SDE of the component; only used to pick ``alpha``
        sample (torch.Tensor): current value of the component
            (indexed as 3-D below, i.e. presumably (batch, ., .) -- TODO confirm)
        score (torch.Tensor): score evaluated at ``sample``
        noise (torch.Tensor): noise already drawn by the caller for this step
        timestep (torch.Tensor): integer indices into ``sde.alphas`` (VPSDE only)
        vec_t (torch.Tensor): per-sample time vector; provides shape and device
        snr (float): signal-to-noise ratio
        scale_eps (float): scale of the injected noise

    Returns:
        torch.Tensor: the corrected (noisy) sample
    """
    grad_norm = torch.norm(score.reshape(score.shape[0], -1), dim=-1).mean()
    noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()

    if isinstance(sde, VPSDE):
        alpha = sde.alphas.to(vec_t.device)[timestep]  # VP
    else:
        alpha = get_ones(vec_t.shape, vec_t.device)  # VE

    step_size = (snr * noise_norm / grad_norm) ** 2 * 2 * alpha
    corrected_mean = sample + step_size[:, None, None] * score
    return corrected_mean + torch.sqrt(step_size * 2)[:, None, None] * noise * scale_eps


def S4_solver(
    sde_x: SDE,
    sde_adj: SDE,
    shape_x: Sequence[int],
    shape_adj: Sequence[int],
    predictor: str = "None",
    corrector: str = "None",
    snr: float = 0.1,
    scale_eps: float = 1.0,
    n_steps: int = 1,
    probability_flow: bool = False,
    continuous: bool = False,
    denoise: bool = True,
    eps: float = 1e-3,
    device: str = "cuda",
    is_cc: bool = False,
    sde_rank2: Optional[SDE] = None,
    shape_rank2: Optional[Sequence[int]] = None,
    d_min: Optional[int] = None,
    d_max: Optional[int] = None,
) -> Union[
    Callable[
        [torch.nn.Module, torch.nn.Module, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]],
    ],
    Callable[
        [torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor],
        Tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]
        ],
    ],
]:
    """Returns a S4 sampler.

    Each iteration of the returned sampler evaluates the scores once, applies a
    score-based correction step to every component, then performs the
    prediction step: a half-step transition, the score drift ``Sdrift * dt``,
    and a second half-step transition.

    Args:
        sde_x (SDE): SDE for the node features
        sde_adj (SDE): SDE for the adjacency matrix
        shape_x (Sequence[int]): shape of the node features
        shape_adj (Sequence[int]): shape of the adjacency matrix
        predictor (str, optional): predictor function. UNUSED HERE. Select from [Euler, Reverse]. Defaults to "None".
        corrector (str, optional): corrector function. UNUSED HERE. Select from [Langevin, None]. Defaults to "None".
        snr (float, optional): signal-to-noise ratio. Defaults to 0.1.
        scale_eps (float, optional): scale of the noise. Defaults to 1.0.
        n_steps (int, optional): number of steps to take. UNUSED HERE. Defaults to 1.
        probability_flow (bool, optional): if True, use probability flow sampling. UNUSED HERE. Defaults to False.
        continuous (bool, optional): if True, use continuous-time SDEs, for the score function. Defaults to False.
        denoise (bool, optional): if True, use denoising diffusion (returns the mean of the reverse SDE). Defaults to True.
        eps (float, optional): epsilon for the reverse-time SDE. Defaults to 1e-3.
        device (str, optional): device to use. Defaults to "cuda".
        is_cc (bool, optional): if True, get S4 sampler for combinatorial complexes. Defaults to False.
        sde_rank2 (Optional[SDE], optional): SDE for the higher-order features. Required if is_cc is True. Defaults to None.
        shape_rank2 (Optional[Sequence[int]], optional): shape of the higher-order features. Required if is_cc is True. Defaults to None.
        d_min (Optional[int], optional): minimum size of rank-2 cells (if combinatorial complexes). Defaults to None.
        d_max (Optional[int], optional): maximum size of rank-2 cells (if combinatorial complexes). Defaults to None.

    Raises:
        ValueError: if is_cc is True but sde_rank2 or shape_rank2 is None

    Returns:
        Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]]]: S4 sampler
    """
    # Fail at construction time rather than deep inside the sampling loop
    if is_cc and (sde_rank2 is None or shape_rank2 is None):
        raise ValueError(
            "S4_solver: `sde_rank2` and `shape_rank2` are required when is_cc=True."
        )

    if not is_cc:

        def s4_solver(
            model_x: torch.nn.Module,
            model_adj: torch.nn.Module,
            init_flags: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]:
            """S4 solver: sample from the model.

            Args:
                model_x (torch.nn.Module): model for the node features
                model_adj (torch.nn.Module): model for the adjacency matrix
                init_flags (torch.Tensor): initial flags

            Returns:
                Tuple[torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]: node features, adjacency matrix, timestep (always 0 here), one complete diffusion trajectory
            """
            # Get score functions
            score_fn_x = get_score_fn(
                sde_x, model_x, train=False, continuous=continuous
            )
            score_fn_adj = get_score_fn(
                sde_adj, model_adj, train=False, continuous=continuous
            )

            # One complete diffusion trajectory (first batch element of each step)
            diff_traj = []

            with torch.no_grad():
                # Initial sample
                x = sde_x.prior_sampling(shape_x).to(device)
                adj = sde_adj.prior_sampling_sym(shape_adj).to(device)
                flags = init_flags

                # Mask the initial sample
                x = mask_x(x, flags)
                adj = mask_adjs(adj, flags)

                diff_steps = sde_adj.N
                timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device)
                # NOTE(review): the step is hard-coded for a unit time horizon,
                # independently of sde_adj.T -- confirm T == 1 for all SDEs.
                dt = -1.0 / diff_steps

                # Reverse diffusion process
                for i in trange(
                    0, (diff_steps), desc="[Sampling]", position=1, leave=False
                ):
                    t = timesteps[i]
                    vec_t = get_ones((shape_adj[0],), device=t.device) * t
                    vec_dt = get_ones((shape_adj[0],), device=t.device) * (dt / 2)

                    # Score computation (once per iteration, before correction)
                    score_x = score_fn_x(x, adj, flags, vec_t)
                    score_adj = score_fn_adj(x, adj, flags, vec_t)

                    Sdrift_x = -sde_x.sde(x, vec_t)[1][:, None, None] ** 2 * score_x
                    Sdrift_adj = (
                        -sde_adj.sde(adj, vec_t)[1][:, None, None] ** 2 * score_adj
                    )

                    # Correction step
                    # NOTE(review): the index is derived from sde_x only and is
                    # reused for sde_adj.alphas -- confirm both share N and T.
                    timestep = (vec_t * (sde_x.N - 1) / sde_x.T).long()

                    x = _s4_corrector_step(
                        sde_x,
                        x,
                        score_x,
                        gen_noise(x, flags, sym=False),
                        timestep,
                        vec_t,
                        snr,
                        scale_eps,
                    )
                    adj = _s4_corrector_step(
                        sde_adj,
                        adj,
                        score_adj,
                        gen_noise(adj, flags),
                        timestep,
                        vec_t,
                        snr,
                        scale_eps,
                    )

                    # Prediction step: first half-step transition
                    mu_x, sigma_x = sde_x.transition(x, vec_t, vec_dt)
                    mu_adj, sigma_adj = sde_adj.transition(adj, vec_t, vec_dt)
                    x = mu_x + sigma_x[:, None, None] * gen_noise(x, flags, sym=False)
                    adj = mu_adj + sigma_adj[:, None, None] * gen_noise(adj, flags)

                    # Score drift over the full step
                    x = x + Sdrift_x * dt
                    adj = adj + Sdrift_adj * dt

                    # Second half-step transition
                    mu_x, sigma_x = sde_x.transition(x, vec_t + vec_dt, vec_dt)
                    mu_adj, sigma_adj = sde_adj.transition(adj, vec_t + vec_dt, vec_dt)
                    x = mu_x + sigma_x[:, None, None] * gen_noise(x, flags, sym=False)
                    adj = mu_adj + sigma_adj[:, None, None] * gen_noise(adj, flags)

                    # The noise-free means are what is reported when denoising
                    x_mean = mu_x
                    adj_mean = mu_adj

                    # Add diffusion trajectory
                    if denoise:
                        diff_traj.append(
                            [x_mean[0].detach().clone(), adj_mean[0].detach().clone()]
                        )
                    else:
                        diff_traj.append(
                            [x[0].detach().clone(), adj[0].detach().clone()]
                        )
                print(" ")
                return (
                    (x_mean if denoise else x),
                    (adj_mean if denoise else adj),
                    0,
                    diff_traj,
                )

    else:

        def s4_solver(
            model_x: torch.nn.Module,
            model_adj: torch.nn.Module,
            model_rank2: torch.nn.Module,
            init_flags: torch.Tensor,
        ) -> Tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]
        ]:
            """S4 solver: sample from the model.

            Args:
                model_x (torch.nn.Module): model for the node features
                model_adj (torch.nn.Module): model for the adjacency matrix
                model_rank2 (torch.nn.Module): model for the higher-order features (rank2 incidence matrix)
                init_flags (torch.Tensor): initial flags

            Returns:
                Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, List[List[torch.Tensor]]]: node features, adjacency matrix, incidence matrix, timestep (always 0 here), one complete diffusion trajectory
            """
            # Get score functions
            score_fn_x = get_score_fn_cc(
                sde_x, model_x, train=False, continuous=continuous
            )
            score_fn_adj = get_score_fn_cc(
                sde_adj, model_adj, train=False, continuous=continuous
            )
            score_fn_rank2 = get_score_fn_cc(
                sde_rank2, model_rank2, train=False, continuous=continuous
            )

            # One complete diffusion trajectory (first batch element of each step)
            diff_traj = []

            with torch.no_grad():
                # Initial sample
                x = sde_x.prior_sampling(shape_x).to(device)
                adj = sde_adj.prior_sampling_sym(shape_adj).to(device)
                rank2 = sde_rank2.prior_sampling(shape_rank2).to(device)
                flags = init_flags

                # Mask the initial sample
                x = mask_x(x, flags)
                adj = mask_adjs(adj, flags)
                rank2 = mask_rank2(rank2, adj.shape[1], d_min, d_max, flags)

                diff_steps = sde_adj.N
                timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device)
                # NOTE(review): the step is hard-coded for a unit time horizon,
                # independently of sde_adj.T -- confirm T == 1 for all SDEs.
                dt = -1.0 / diff_steps

                # Reverse diffusion process
                for i in trange(
                    0, (diff_steps), desc="[Sampling]", position=1, leave=False
                ):
                    t = timesteps[i]
                    vec_t = get_ones((shape_adj[0],), device=t.device) * t
                    vec_dt = get_ones((shape_adj[0],), device=t.device) * (dt / 2)

                    # Score computation (once per iteration, before correction)
                    score_x = score_fn_x(x, adj, rank2, flags, vec_t)
                    score_adj = score_fn_adj(x, adj, rank2, flags, vec_t)
                    score_rank2 = score_fn_rank2(x, adj, rank2, flags, vec_t)

                    Sdrift_x = -sde_x.sde(x, vec_t)[1][:, None, None] ** 2 * score_x
                    Sdrift_adj = (
                        -sde_adj.sde(adj, vec_t)[1][:, None, None] ** 2 * score_adj
                    )
                    Sdrift_rank2 = (
                        -sde_rank2.sde(rank2, vec_t)[1][:, None, None] ** 2
                        * score_rank2
                    )

                    # Correction step
                    # NOTE(review): the index is derived from sde_x only and is
                    # reused for sde_adj.alphas and sde_rank2.alphas -- confirm
                    # all SDEs share N and T.
                    timestep = (vec_t * (sde_x.N - 1) / sde_x.T).long()

                    x = _s4_corrector_step(
                        sde_x,
                        x,
                        score_x,
                        gen_noise(x, flags, sym=False),
                        timestep,
                        vec_t,
                        snr,
                        scale_eps,
                    )
                    adj = _s4_corrector_step(
                        sde_adj,
                        adj,
                        score_adj,
                        gen_noise(adj, flags),
                        timestep,
                        vec_t,
                        snr,
                        scale_eps,
                    )
                    rank2 = _s4_corrector_step(
                        sde_rank2,
                        rank2,
                        score_rank2,
                        gen_noise_rank2(rank2, adj.shape[1], d_min, d_max, flags),
                        timestep,
                        vec_t,
                        snr,
                        scale_eps,
                    )

                    # Prediction step: first half-step transition
                    mu_x, sigma_x = sde_x.transition(x, vec_t, vec_dt)
                    mu_adj, sigma_adj = sde_adj.transition(adj, vec_t, vec_dt)
                    mu_rank2, sigma_rank2 = sde_rank2.transition(rank2, vec_t, vec_dt)
                    x = mu_x + sigma_x[:, None, None] * gen_noise(x, flags, sym=False)
                    adj = mu_adj + sigma_adj[:, None, None] * gen_noise(adj, flags)
                    rank2 = mu_rank2 + sigma_rank2[:, None, None] * gen_noise_rank2(
                        rank2, adj.shape[1], d_min, d_max, flags
                    )

                    # Score drift over the full step
                    x = x + Sdrift_x * dt
                    adj = adj + Sdrift_adj * dt
                    rank2 = rank2 + Sdrift_rank2 * dt

                    # Second half-step transition
                    mu_x, sigma_x = sde_x.transition(x, vec_t + vec_dt, vec_dt)
                    mu_adj, sigma_adj = sde_adj.transition(adj, vec_t + vec_dt, vec_dt)
                    mu_rank2, sigma_rank2 = sde_rank2.transition(
                        rank2, vec_t + vec_dt, vec_dt
                    )
                    x = mu_x + sigma_x[:, None, None] * gen_noise(x, flags, sym=False)
                    adj = mu_adj + sigma_adj[:, None, None] * gen_noise(adj, flags)
                    rank2 = mu_rank2 + sigma_rank2[:, None, None] * gen_noise_rank2(
                        rank2, adj.shape[1], d_min, d_max, flags
                    )

                    # The noise-free means are what is reported when denoising
                    x_mean = mu_x
                    adj_mean = mu_adj
                    rank2_mean = mu_rank2

                    # Add diffusion trajectory
                    if denoise:
                        diff_traj.append(
                            [
                                x_mean[0].detach().clone(),
                                adj_mean[0].detach().clone(),
                                rank2_mean[0].detach().clone(),
                            ]
                        )
                    else:
                        diff_traj.append(
                            [
                                x[0].detach().clone(),
                                adj[0].detach().clone(),
                                rank2[0].detach().clone(),
                            ]
                        )
                print(" ")
                return (
                    (x_mean if denoise else x),
                    (adj_mean if denoise else adj),
                    (rank2_mean if denoise else rank2),
                    0,
                    diff_traj,
                )

    return s4_solver