Source code for ccsd.src.models.ScoreNetwork_X

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""ScoreNetwork_X.py: ScoreNetworkX and ScoreNetworkX_GMH classes.
These are ScoreNetwork models for the node feature matrix X.

Adapted from Jo, J. et al. (2022)

Almost left untouched.
"""

from typing import Optional

import torch
import torch.nn.functional as F

from ccsd.src.models.attention import AttentionLayer
from ccsd.src.models.layers import MLP, DenseGCNConv
from ccsd.src.utils.graph_utils import mask_x, pow_tensor


class ScoreNetworkX(torch.nn.Module):
    """ScoreNetworkX network model.

    Returns the score with respect to the node feature matrix X. A stack of
    DenseGCNConv layers (each followed by tanh) is applied to X; the input and
    every layer's output are concatenated and passed through a final MLP.
    """

    def __init__(
        self,
        max_feat_num: int,
        depth: int,
        nhid: int,
        use_bn: bool = False,
        is_cc: bool = False,
    ) -> None:
        """Initialize ScoreNetworkX.

        Args:
            max_feat_num (int): maximum number of node features (input and
                output dimension of the network)
            depth (int): number of DenseGCNConv layers
            nhid (int): number of hidden units in DenseGCNConv layers
            use_bn (bool, optional): True if we use batch normalization in the
                MLP. Defaults to False.
            is_cc (bool, optional): True if we generate combinatorial
                complexes (selects forward_cc instead of forward_graph).
                Defaults to False.
        """
        super().__init__()

        # Initialize parameters
        self.nfeat = max_feat_num
        self.depth = depth
        self.nhid = nhid
        self.use_bn = use_bn
        self.is_cc = is_cc

        # Initialize DenseGCNConv layers: the first maps F -> H, the others H -> H
        self.layers = torch.nn.ModuleList()
        for k in range(self.depth):
            in_dim = self.nfeat if k == 0 else self.nhid
            self.layers.append(DenseGCNConv(in_dim, self.nhid))

        # Final layer is an MLP on the concatenation of the input and all the
        # layers' outputs, i.e. F + depth * H features per node
        self.fdim = self.nfeat + self.depth * self.nhid
        self.final = MLP(
            num_layers=3,
            input_dim=self.fdim,
            hidden_dim=2 * self.fdim,
            output_dim=self.nfeat,
            use_bn=self.use_bn,
            activate_func=F.elu,
        )

        # Initialize activation function
        self.activation = torch.tanh

        # Reset the parameters
        self.reset_parameters()

    def __repr__(self) -> str:
        """String representation of the model.

        Returns:
            str: string representation of the model
        """
        return f"{self.__class__.__name__}(depth={self.depth}, nhid={self.nhid}, use_bn={self.use_bn}, is_cc={self.is_cc})"

    def reset_parameters(self) -> None:
        """Reset the parameters of the model."""
        # Reset the parameters of the DenseGCNConv layers
        for layer in self.layers:
            layer.reset_parameters()
        # Reset the parameters of the final MLP layer
        self.final.reset_parameters()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass, dispatched on ``self.is_cc``.

        Calls ``forward_cc(x, adj, rank2, flags)`` when ``is_cc`` is True,
        otherwise ``forward_graph(x, adj, flags)``. Arguments are forwarded
        unchanged.

        Dispatching here, rather than assigning ``self.forward`` to a bound
        method in ``__init__``, avoids storing a reference to the module
        inside its own instance dict (a self-referential cycle).

        Returns:
            torch.Tensor: score with respect to the node feature matrix
                (B x N x F)
        """
        if self.is_cc:
            return self.forward_cc(*args, **kwargs)
        return self.forward_graph(*args, **kwargs)

    def forward_graph(
        self, x: torch.Tensor, adj: torch.Tensor, flags: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Forward pass of the ScoreNetworkX model.

        Args:
            x (torch.Tensor): node feature matrix (B x N x F)
            adj (torch.Tensor): adjacency matrix (B x N x N)
            flags (Optional[torch.Tensor]): optional mask matrix (B x N x 1)

        Returns:
            torch.Tensor: score with respect to the node feature matrix
                (B x N x F)
        """
        # Apply all the DenseGCN layers, keeping the input and every output
        x_list = [x]
        for layer in self.layers:
            x = self.activation(layer(x, adj))
            x_list.append(x)

        # Concatenate all the layers' outputs (B x N x (F + num_layers x H))
        # B batch size, N max number of nodes, F number of features,
        # H number of hidden units
        xs = torch.cat(x_list, dim=-1)
        out_shape = (adj.shape[0], adj.shape[1], -1)

        # Apply the final MLP layer
        x = self.final(xs).view(*out_shape)

        # Mask the output
        x = mask_x(x, flags)
        return x

    def forward_cc(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Forward pass of the ScoreNetworkX model for combinatorial complexes.

        The rank2 incidence matrix is accepted for interface compatibility but
        is not used: this simply delegates to ``forward_graph``.

        Args:
            x (torch.Tensor): node feature matrix (B x N x F)
            adj (torch.Tensor): adjacency matrix (B x N x N)
            rank2 (torch.Tensor): rank2 incidence matrix (B x (NC2) x K)
            flags (Optional[torch.Tensor]): optional mask matrix (B x N x 1)

        Returns:
            torch.Tensor: score with respect to the node feature matrix
                (B x N x F)
        """
        return self.forward_graph(x, adj, flags)
class ScoreNetworkX_GMH(torch.nn.Module):
    """ScoreNetworkX_GMH network model.

    Returns the score with respect to the node feature matrix X. A stack of
    AttentionLayer layers (each followed by tanh on the node features) is
    applied to X and to a multi-channel expansion of the adjacency matrix; the
    input and every layer's node output are concatenated and passed through a
    final MLP.
    """

    def __init__(
        self,
        max_feat_num: int,
        depth: int,
        nhid: int,
        num_linears: int,
        c_init: int,
        c_hid: int,
        c_final: int,
        adim: int,
        num_heads: int = 4,
        conv: str = "GCN",
        use_bn: bool = False,
        is_cc: bool = False,
    ) -> None:
        """Initialize ScoreNetworkX_GMH.

        Args:
            max_feat_num (int): maximum number of node features (input and
                output dimension of the network)
            depth (int): number of AttentionLayer layers
            nhid (int): number of hidden units (node feature output dimension)
                of each AttentionLayer
            num_linears (int): number of linear layers in AttentionLayer
            c_init (int): number of adjacency channels fed to the first
                AttentionLayer. The input adjacency matrix is expanded into
                this many channels with pow_tensor.
            c_hid (int): number of adjacency channels output by the
                AttentionLayer(s) before the last one
            c_final (int): number of adjacency channels output by the last
                AttentionLayer (that adjacency output is not used by this
                model's forward pass)
            adim (int): attention dimension (except for the first layer,
                which uses nhid)
            num_heads (int, optional): number of heads for the Attention.
                Defaults to 4.
            conv (str, optional): type of convolutional layer, choose from
                [GCN, MLP]. Defaults to "GCN".
            use_bn (bool, optional): True if we use batch normalization in the
                MLP and the AttentionLayer(s). Defaults to False.
            is_cc (bool, optional): True if we generate combinatorial
                complexes (selects forward_cc instead of forward_graph).
                Defaults to False.
        """
        super().__init__()

        # Initialize parameters
        self.depth = depth
        self.c_init = c_init
        self.use_bn = use_bn
        self.is_cc = is_cc

        # Initialize AttentionLayer layers. The first-layer check comes first,
        # so with depth == 1 the single layer outputs c_hid channels.
        self.layers = torch.nn.ModuleList()
        for k in range(self.depth):
            if k == 0:
                # First layer: raw node features and c_init adjacency channels
                layer = AttentionLayer(
                    num_linears,
                    max_feat_num,
                    nhid,
                    nhid,
                    c_init,
                    c_hid,
                    num_heads,
                    conv,
                    self.use_bn,
                )
            elif k == self.depth - 1:
                # Last layer: outputs c_final adjacency channels
                layer = AttentionLayer(
                    num_linears,
                    nhid,
                    adim,
                    nhid,
                    c_hid,
                    c_final,
                    num_heads,
                    conv,
                    self.use_bn,
                )
            else:
                # Intermediate layers: c_hid -> c_hid adjacency channels
                layer = AttentionLayer(
                    num_linears,
                    nhid,
                    adim,
                    nhid,
                    c_hid,
                    c_hid,
                    num_heads,
                    conv,
                    self.use_bn,
                )
            self.layers.append(layer)

        # Final layer is an MLP on the concatenation of the input and all the
        # layers' outputs, i.e. F + depth * H features per node
        fdim = max_feat_num + depth * nhid
        self.final = MLP(
            num_layers=3,
            input_dim=fdim,
            hidden_dim=2 * fdim,
            output_dim=max_feat_num,
            use_bn=self.use_bn,
            activate_func=F.elu,
        )

        # Initialize activation function
        self.activation = torch.tanh

        # Reset the parameters
        self.reset_parameters()

    def __repr__(self) -> str:
        """String representation of the ScoreNetworkX_GMH model.

        Returns:
            str: string representation
        """
        return f"{self.__class__.__name__}(depth={self.depth}, c_init={self.c_init}, use_bn={self.use_bn}, is_cc={self.is_cc})"

    def reset_parameters(self) -> None:
        """Reset the parameters of the ScoreNetworkX_GMH model."""
        # Reset the parameters of the AttentionLayer layers
        for attn in self.layers:
            attn.reset_parameters()
        # Reset the parameters of the final MLP
        self.final.reset_parameters()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass, dispatched on ``self.is_cc``.

        Calls ``forward_cc(x, adj, rank2, flags)`` when ``is_cc`` is True,
        otherwise ``forward_graph(x, adj, flags)``. Arguments are forwarded
        unchanged.

        Dispatching here, rather than assigning ``self.forward`` to a bound
        method in ``__init__``, avoids storing a reference to the module
        inside its own instance dict (a self-referential cycle).

        Returns:
            torch.Tensor: score with respect to the node feature matrix
                (B x N x F)
        """
        if self.is_cc:
            return self.forward_cc(*args, **kwargs)
        return self.forward_graph(*args, **kwargs)

    def forward_graph(
        self, x: torch.Tensor, adj: torch.Tensor, flags: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Forward pass of the ScoreNetworkX_GMH model.

        Args:
            x (torch.Tensor): node feature matrix (B x N x F)
            adj (torch.Tensor): adjacency matrix (B x N x N)
            flags (Optional[torch.Tensor]): optional mask matrix (B x N x 1)

        Returns:
            torch.Tensor: score with respect to the node feature matrix
                (B x N x F)
        """
        # Expand the adjacency matrix into c_init channels (powers of adj)
        adjc = pow_tensor(adj, self.c_init)

        # Apply all the AttentionLayer layers; only the node outputs are kept,
        # the adjacency output of each layer feeds the next layer
        x_list = [x]
        for layer in self.layers:
            x, adjc = layer(x, adjc, flags)
            x = self.activation(x)
            x_list.append(x)

        # Concatenate all the layers' outputs (B x N x (F + num_layers x H))
        # B batch size, N max number of nodes, F number of features,
        # H number of hidden units
        xs = torch.cat(x_list, dim=-1)
        out_shape = (adj.shape[0], adj.shape[1], -1)

        # Apply the final MLP layer
        x = self.final(xs).view(*out_shape)

        # Mask the output
        x = mask_x(x, flags)
        return x

    def forward_cc(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Forward pass of the ScoreNetworkX_GMH model for combinatorial complexes.

        The rank2 incidence matrix is accepted for interface compatibility but
        is not used: this simply delegates to ``forward_graph``.

        Args:
            x (torch.Tensor): node feature matrix (B x N x F)
            adj (torch.Tensor): adjacency matrix (B x N x N)
            rank2 (torch.Tensor): rank2 incidence matrix (B x (NC2) x K)
            flags (Optional[torch.Tensor]): optional mask matrix (B x N x 1)

        Returns:
            torch.Tensor: score with respect to the node feature matrix
                (B x N x F)
        """
        return self.forward_graph(x, adj, flags)