Source code for ccsd.src.models.ScoreNetwork_A

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""ScoreNetwork_A.py: ScoreNetworkA and BaselineNetwork classes.
These are ScoreNetwork models for the adjacency matrix A.

Adapted from Jo, J. & al (2022)

Almost left untouched.
"""

from typing import Optional, Tuple

import torch
import torch.nn.functional as torch_func

from ccsd.src.models.attention import AttentionLayer
from ccsd.src.models.layers import MLP, DenseGCNConv
from ccsd.src.utils.cc_utils import default_mask
from ccsd.src.utils.graph_utils import (
    mask_adjs,
    mask_x,
    node_feature_to_matrix,
    pow_tensor,
)


[docs] class BaselineNetworkLayer(torch.nn.Module): """BaselineNetworkLayer that operates on tensors derived from an adjacency matrix A. Used in the BaselineNetwork model. """
[docs] def __init__( self, num_linears: int, conv_input_dim: int, conv_output_dim: int, input_dim: int, output_dim: int, use_bn: bool = False, ) -> None: """Initialize the BaselineNetworkLayer. Args: num_linears (int): number of linear layers in the MLP (except the first one) conv_input_dim (int): input dimension of the DenseGCNConv layers conv_output_dim (int): output dimension of the DenseGCNConv layers input_dim (int): number of DenseGCNConv layers (part of the input dimension of the final MLP) output_dim (int): output dimension of the final MLP use_bn (bool, optional): whether to use batch normalization in the MLP. Defaults to False. """ super(BaselineNetworkLayer, self).__init__() # Initialize the parameters and the layers self.use_bn = use_bn self.convs = torch.nn.ModuleList() for _ in range(input_dim): self.convs.append(DenseGCNConv(conv_input_dim, conv_output_dim)) self.hidden_dim = max(input_dim, output_dim) self.mlp_in_dim = input_dim + 2 * conv_output_dim self.mlp = MLP( num_linears, self.mlp_in_dim, self.hidden_dim, output_dim, use_bn=self.use_bn, activate_func=torch_func.elu, ) self.multi_channel = MLP( 2, input_dim * conv_output_dim, self.hidden_dim, conv_output_dim, use_bn=self.use_bn, activate_func=torch_func.elu, ) # Reset the parameters self.reset_parameters()
def __repr__(self) -> str: """Representation of the BaselineNetworkLayer. Returns: str: representation of the BaselineNetworkLayer """ return ( f"{self.__class__.__name__}(" f"num_linears={self.mlp.num_linears}, " f"conv_input_dim={self.convs[0].in_channels}, " f"conv_output_dim={self.convs[0].out_channels}, " f"input_dim={len(self.convs)}, " f"output_dim={self.mlp.out_dim}, " f"use_bn={self.use_bn})" )
[docs] def reset_parameters(self) -> None: """Reset the parameters of the BaselineNetworkLayer.""" # Reset the parameters of the DenseGCNConv layers, the MLPs, and the multi-channel MLP for conv in self.convs: conv.reset_parameters() self.mlp.reset_parameters() self.multi_channel.reset_parameters()
[docs] def forward( self, x: torch.Tensor, adj: torch.Tensor, flags: Optional[torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass of the BaselineNetworkLayer. Args: x (torch.Tensor): node feature matrix adj (torch.Tensor): adjacency matrix flags (Optional[torch.Tensor]): optional flags for the node features Returns: Tuple[torch.Tensor, torch.Tensor]: node feature matrix and adjacency matrix """ # Apply all the DenseGCNConv layers x_list = [] for k in range(len(self.convs)): _x = self.convs[k](x, adj[:, k, :, :]) x_list.append(_x) # Concatenate the outputs of the DenseGCNConv layers, apply the multi-channel MLP, and mask the output x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags) x_out = torch.tanh(x_out) # Convert the node feature matrix to a node feature adjacency tensor x_matrix = node_feature_to_matrix(x_out) # Concatenate the node feature adjacency tensor with the original adjacency tensor mlp_in = torch.cat([x_matrix, adj.permute(0, 2, 3, 1)], dim=-1) shape = mlp_in.shape # Apply the final MLP on the concatenated adjacency tensor mlp_out = self.mlp(mlp_in.view(-1, shape[-1])) _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0, 3, 1, 2) _adj = _adj + _adj.transpose(-1, -2) # Mask the adjacency tensor adj_out = mask_adjs(_adj, flags) return x_out, adj_out
[docs] class BaselineNetwork(torch.nn.Module): """BaselineNetwork to calculate the score with respect to the adjacency matrix A."""
[docs] def __init__( self, max_feat_num: int, max_node_num: int, nhid: int, num_layers: int, num_linears: int, c_init: int, c_hid: int, c_final: int, adim: int, num_heads: int = 4, conv: str = "GCN", use_bn: bool = False, is_cc: bool = False, ) -> None: """Initialize the BaselineNetwork. Args: max_feat_num (int): maximum number of node features max_node_num (int): maximum number of nodes in the graphs nhid (int): number of hidden units in BaselineNetworkLayer layers num_layers (int): number of BaselineNetworkLayer layers num_linears (int): number of linear layers in the MLP of each BaselineNetworkLayer c_init (int): input dimension of the BaselineNetworkLayer (number of DenseGCNConv) Also the number of power iterations to "duplicate" the adjacency matrix as an input c_hid (int): number of hidden units in the MLP of each BaselineNetworkLayer c_final (int): output dimension of the MLP of the last BaselineNetworkLayer adim (int): UNUSED HERE. attention dimension (except for the first layer). num_heads (int, optional): UNUSED HERE. number of heads for the Attention. Defaults to 4. conv (str, optional): UNUSED HERE. type of convolutional layer, choose from [GCN, MLP]. Defaults to "GCN". use_bn (bool, optional): whether to use batch normalization in the MLP and the BaselineNetworkLayer(s). Defaults to False. is_cc (bool, optional): True if we generate combinatorial complexes. Defaults to False. """ super(BaselineNetwork, self).__init__() # Initialize the parameters self.max_feat_num = max_feat_num self.max_node_num = max_node_num self.nhid = nhid self.num_layers = num_layers self.num_linears = num_linears self.c_init = c_init self.c_hid = c_hid self.c_final = c_final self.use_bn = use_bn self.is_cc = is_cc # Initialize the layers self.layers = torch.nn.ModuleList() for k in range(self.num_layers): if not (k): # first layer self.layers.append( BaselineNetworkLayer( self.num_linears, self.max_feat_num, self.nhid, self.c_init, self.c_hid, self.use_bn, ) ) elif k == (self.num_layers - 1): # last layer self.layers.append( BaselineNetworkLayer( self.num_linears, self.nhid, self.nhid, self.c_hid, self.c_final, self.use_bn, ) ) else: # intermediate layers self.layers.append( BaselineNetworkLayer( self.num_linears, self.nhid, self.nhid, self.c_hid, self.c_hid, self.use_bn, ) ) # Initialize the final MLP self.fdim = self.c_hid * (self.num_layers - 1) + self.c_final + self.c_init self.final = MLP( num_layers=3, input_dim=self.fdim, hidden_dim=2 * self.fdim, output_dim=1, use_bn=self.use_bn, activate_func=torch_func.elu, ) # Initialize the mask device = "cuda:0" if torch.cuda.is_available() else "cpu" self.mask = default_mask(self.max_node_num, device) self.mask.unsqueeze_(0) # Pick the right forward function if not (self.is_cc): self.forward = self.forward_graph else: self.forward = self.forward_cc # Reset the parameters self.reset_parameters()
def __repr__(self) -> str: """Representation of the BaselineNetwork. Returns: str: representation of the BaselineNetwork """ return ( f"{self.__class__.__name__}(" f"max_feat_num={self.max_feat_num}, " f"max_node_num={self.max_node_num}, " f"nhid={self.nhid}, " f"num_layers={self.num_layers}, " f"num_linears={self.num_linears}, " f"c_init={self.c_init}, " f"c_hid={self.c_hid}, " f"c_final={self.c_final}, " f"use_bn={self.use_bn}, " f"is_cc={self.is_cc})" )
[docs] def reset_parameters(self) -> None: """Reset the parameters of the BaselineNetwork.""" # Reset the parameters of the BaselineNetworkLayer layers for layer in self.layers: layer.reset_parameters() # Reset the parameters of the final MLP self.final.reset_parameters()
[docs] def forward_graph( self, x: torch.Tensor, adj: torch.Tensor, flags: Optional[torch.Tensor] = None ) -> torch.Tensor: """Forward pass of the BaselineNetwork. Returns the score with respect to the adjacency matrix A. Args: x (torch.Tensor): node feature matrix adj (torch.Tensor): adjacency matrix flags (Optional[torch.Tensor], optional): optional flags for the score. Defaults to None. Returns: torch.Tensor: score with respect to the adjacency matrix A """ # Duplicate the adjacency matrix as an input by creating power tensors adjc = pow_tensor(adj, self.c_init) # Apply all the BaselineNetworkLayer layers adj_list = [adjc] for k in range(self.num_layers): x, adjc = self.layers[k](x, adjc, flags) adj_list.append(adjc) # Concatenate the output of the BaselineNetworkLayer layers (B x N x N x (c_init + c_hid * (num_layers - 1) + c_final) adjs = torch.cat(adj_list, dim=1).permute(0, 2, 3, 1) out_shape = adjs.shape[:-1] # B x N x N # Apply the final MLP on the concatenated adjacency tensor to compute the score score = self.final(adjs).view(*out_shape) # Mask the score self.mask = self.mask.to(score.device) score = score * self.mask # Mask the score with respect to the flags score = mask_adjs(score, flags) return score
[docs] def forward_cc( self, x: torch.Tensor, adj: torch.Tensor, rank2: torch.Tensor, flags: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass of the BaselineNetwork. Returns the score with respect to the adjacency matrix A. Args: x (torch.Tensor): node feature matrix adj (torch.Tensor): adjacency matrix rank2 (torch.Tensor): rank2 incidence matrix flags (Optional[torch.Tensor], optional): optional flags for the score. Defaults to None. Returns: torch.Tensor: score with respect to the adjacency matrix A """ return self.forward_graph(x, adj, flags)
[docs] class ScoreNetworkA(torch.nn.Module): """ScoreNetworkA to calculate the score with respect to the adjacency matrix A."""
[docs] def __init__( self, max_feat_num: int, max_node_num: int, nhid: int, num_layers: int, num_linears: int, c_init: int, c_hid: int, c_final: int, adim: int, num_heads: int = 4, conv: str = "GCN", use_bn: bool = False, is_cc: bool = False, ) -> None: """Initialize the ScoreNetworkA model. Args: max_feat_num (int): maximum number of node features max_node_num (int): maximum number of nodes in the graphs nhid (int): number of hidden units in AttentionLayer layers num_layers (int): number of AttentionLayer layers num_linears (int): number of linear layers in the MLP of each AttentionLayer c_init (int): input dimension of the AttentionLayer (number of DenseGCNConv) Also the number of power iterations to "duplicate" the adjacency matrix as an input c_hid (int): number of hidden units in the MLP of each AttentionLayer c_final (int): output dimension of the MLP of the last AttentionLayer adim (int): attention dimension (except for the first layer). num_heads (int, optional): number of heads for the Attention. Defaults to 4. conv (str, optional): type of convolutional layer, choose from [GCN, MLP]. Defaults to "GCN". use_bn (bool, optional): whether to use batch normalization in the MLP and the AttentionLayer(s). Defaults to False. is_cc (bool, optional): True if we generate combinatorial complexes. Defaults to False. """ super(ScoreNetworkA, self).__init__() # Initialize the parameters self.max_feat_num = max_feat_num self.max_node_num = max_node_num self.nhid = nhid self.num_layers = num_layers self.num_linears = num_linears self.c_init = c_init self.c_hid = c_hid self.c_final = c_final self.adim = adim self.num_heads = num_heads self.conv = conv self.use_bn = use_bn self.is_cc = is_cc # Initialize the layers self.layers = torch.nn.ModuleList() for k in range(self.num_layers): if not (k): # first layer self.layers.append( AttentionLayer( self.num_linears, self.max_feat_num, self.nhid, self.nhid, self.c_init, self.c_hid, self.num_heads, self.conv, self.use_bn, ) ) elif k == (self.num_layers - 1): # last layer self.layers.append( AttentionLayer( self.num_linears, self.nhid, self.adim, self.nhid, self.c_hid, self.c_final, self.num_heads, self.conv, self.use_bn, ) ) else: # intermediate layers self.layers.append( AttentionLayer( self.num_linears, self.nhid, self.adim, self.nhid, self.c_hid, self.c_hid, self.num_heads, self.conv, self.use_bn, ) ) # Initialize the final MLP self.fdim = self.c_hid * (self.num_layers - 1) + self.c_final + self.c_init self.final = MLP( num_layers=3, input_dim=self.fdim, hidden_dim=2 * self.fdim, output_dim=1, use_bn=self.use_bn, activate_func=torch_func.elu, ) # Initialize the mask device = "cuda:0" if torch.cuda.is_available() else "cpu" self.mask = default_mask(self.max_node_num, device) self.mask = self.mask.unsqueeze(0) # Pick the right forward function if not (self.is_cc): self.forward = self.forward_graph else: self.forward = self.forward_cc # Reset the parameters self.reset_parameters()
def __repr__(self) -> str: """Representation of the ScoreNetworkA model. Returns: str: representation of the ScoreNetworkA model """ return ( f"{self.__class__.__name__}(" f"max_feat_num={self.max_feat_num}, " f"max_node_num={self.max_node_num}, " f"nhid={self.nhid}, " f"num_layers={self.num_layers}, " f"num_linears={self.num_linears}, " f"c_init={self.c_init}, " f"c_hid={self.c_hid}, " f"c_final={self.c_final}, " f"adim={self.adim}, " f"num_heads={self.num_heads}, " f"conv={self.conv}, " f"use_bn={self.use_bn}, " f"is_cc={self.is_cc})" )
[docs] def reset_parameters(self) -> None: """Reset the parameters of the model.""" # Reset the parameters of the AttentionLayer layers for attn in self.layers: attn.reset_parameters() # Reset the parameters of the final MLP self.final.reset_parameters()
[docs] def forward_graph( self, x: torch.Tensor, adj: torch.Tensor, flags: Optional[torch.Tensor] = None ) -> torch.Tensor: """Forward pass of the ScoreNetworkA. Returns the score with respect to the adjacency matrix A. Args: x (torch.Tensor): node feature matrix adj (torch.Tensor): adjacency matrix flags (Optional[torch.Tensor], optional): optional flags for the score. Defaults to None. Returns: torch.Tensor: score with respect to the adjacency matrix A """ # Duplicate the adjacency matrix as an input by creating power tensors adjc = pow_tensor(adj, self.c_init) # Apply all the AttentionLayer layers adj_list = [adjc] for k in range(self.num_layers): x, adjc = self.layers[k](x, adjc, flags) adj_list.append(adjc) # Concatenate the output of the AttentionLayer layers (B x N x N x (c_init + c_hid * (num_layers - 1) + c_final) adjs = torch.cat(adj_list, dim=1).permute(0, 2, 3, 1) out_shape = adjs.shape[:-1] # B x N x N # Apply the final MLP on the concatenated adjacency tensor to compute the score score = self.final(adjs).view(*out_shape) # Mask the score self.mask = self.mask.to(score.device) score = score * self.mask # Mask the score with respect to the flags score = mask_adjs(score, flags) return score
[docs] def forward_cc( self, x: torch.Tensor, adj: torch.Tensor, rank2: torch.Tensor, flags: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass of the ScoreNetworkA. Returns the score with respect to the adjacency matrix A. Args: x (torch.Tensor): node feature matrix adj (torch.Tensor): adjacency matrix rank2 (torch.Tensor): rank2 incidence matrix flags (Optional[torch.Tensor], optional): optional flags for the score. Defaults to None. Returns: torch.Tensor: score with respect to the adjacency matrix A """ return self.forward_graph(x, adj, flags)