Source code for ccsd.src.models.hodge_attention

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""attention.py: HodgeAttention and HodgeAdjAttentionLayer classes for the ScoreNetwork models.
"""

import math
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as torch_func

from ccsd.src.models.hodge_layers import DenseHCNConv
from ccsd.src.models.layers import MLP
from ccsd.src.utils.cc_utils import get_rank2_dim, mask_hodge_adjs, mask_rank2


class HodgeAttention(torch.nn.Module):
    """Hodge Combinatorial Complexes Multi-Head (HCCMH) Attention layer

    Used in the HodgeAdjAttentionLayer below. Queries and keys are computed
    from the rank-2 incidence matrix (with the hodge adjacency matrix as the
    propagation operator when conv == "HCN"). The value is the hodge adjacency
    matrix multiplied with the rank-2 incidence matrix.
    """

    def __init__(
        self,
        in_dim: int,
        attn_dim: int,
        out_dim: int,
        num_heads: int = 4,
        conv: str = "HCN",
    ) -> None:
        """Initialize the HodgeAttention layer

        Args:
            in_dim (int): input dimension (last dimension of the rank-2 incidence matrix)
            attn_dim (int): attention dimension, split into num_heads chunks in forward
            out_dim (int): output dimension, used to scale the attention scores
            num_heads (int, optional): number of attention heads. Defaults to 4.
            conv (str, optional): type of convolutional layer, choose from [HCN, MLP]. Defaults to "HCN".

        Raises:
            NotImplementedError: if conv is not one of [HCN, MLP]
        """
        super().__init__()
        # Initialize the parameters of the HodgeAttention layer
        self.num_heads = num_heads
        self.in_dim = in_dim
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv
        # Initialize the query / key / value networks
        self.ccnn_q, self.ccnn_k, self.ccnn_v = self.get_ccnn(
            self.in_dim, self.attn_dim, self.out_dim, self.conv
        )
        self.activation = torch.tanh
        # Reset the parameters of the networks
        self.reset_parameters()

    def __repr__(self) -> str:
        """Representation of the HodgeAttention layer

        Returns:
            str: representation of the HodgeAttention layer
        """
        return (
            f"{self.__class__.__name__}("
            f"in_dim={self.in_dim}, "
            f"attn_dim={self.attn_dim}, "
            f"out_dim={self.out_dim}, "
            f"num_heads={self.num_heads}, "
            f"conv={self.conv})"
        )

    def reset_parameters(self) -> None:
        """Reset the parameters of the HodgeAttention layer

        Only the query and key networks are reset: the value network is a
        parameter-free torch.nn.Identity (see get_ccnn).
        """
        self.ccnn_q.reset_parameters()
        self.ccnn_k.reset_parameters()

    def forward(
        self,
        hodge_adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the HodgeAttention layer. Returns the value, attention matrix.

        Args:
            hodge_adj (torch.Tensor): hodge adjacency matrix, batched (B x (NC2) x (NC2))
            rank2 (torch.Tensor): rank-2 incidence matrix (B x (NC2) x in_dim)
            flags (Optional[torch.Tensor]): node flags (not used by this layer)
            attention_mask (Optional[torch.Tensor], optional): additive attention mask
                for the attention matrix (one copy per batch element, shared by all
                heads). Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: value (hodge_adj @ rank2), and the
                head-averaged, symmetrized attention matrix with the shape of hodge_adj
        """
        if self.conv == "HCN":
            Q = self.ccnn_q(hodge_adj, rank2)
            K = self.ccnn_k(hodge_adj, rank2)
        else:
            # The MLPs were built with input_dim=in_dim, the feature dimension of
            # rank2, so they must be applied to rank2 (not to the (NC2) x (NC2)
            # hodge adjacency matrix, whose last dimension is generally != in_dim).
            Q = self.ccnn_q(rank2)
            K = self.ccnn_k(rank2)
        # ccnn_v is an identity: the value is the hodge adjacency applied to rank2
        V = torch.bmm(hodge_adj, self.ccnn_v(rank2))

        # Split the attention dimension into heads stacked along the batch axis
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        # Compute the (scaled) attention score
        attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim)
        if attention_mask is not None:
            # Duplicate the attention mask for each head (heads stacked on dim 0)
            attention_mask = torch.cat([attention_mask] * self.num_heads, 0)
            attention_score = attention_mask + attention_score
        A = self.activation(attention_score)

        # A: (num_heads x B) x (NC2) x (NC2) -> average over the heads
        A = A.view(-1, *hodge_adj.shape).mean(dim=0)
        A = (A + A.transpose(-1, -2)) / 2  # symmetrize the attention matrix
        return V, A

    def get_ccnn(
        self, in_dim: int, attn_dim: int, out_dim: int, conv: str = "HCN"
    ) -> Tuple[
        Union[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            Callable[[torch.Tensor], torch.Tensor],
        ],
        Union[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            Callable[[torch.Tensor], torch.Tensor],
        ],
        Callable[[torch.Tensor], torch.Tensor],
    ]:
        """Initialize the query, key and value networks

        Args:
            in_dim (int): input dimension
            attn_dim (int): attention dimension
            out_dim (int): output dimension (unused: the value network is an identity)
            conv (str, optional): type of convolutional layer, choose from [HCN, MLP]. Defaults to "HCN".

        Raises:
            NotImplementedError: raise an error if the convolutional layer is not implemented

        Returns:
            Tuple[...]: three networks, one for the query, one for the key
                (two-argument DenseHCNConv for HCN, one-argument MLP for MLP),
                and one for the value (a one-argument torch.nn.Identity)
        """
        if conv == "HCN":
            ccnn_q = DenseHCNConv(in_dim, attn_dim)
            ccnn_k = DenseHCNConv(in_dim, attn_dim)
        elif conv == "MLP":
            num_layers = 2
            ccnn_q = MLP(
                num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh
            )
            ccnn_k = MLP(
                num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh
            )
        else:
            raise NotImplementedError(f"Convolution layer {conv} not implemented.")
        # The value network is parameter-free: forward computes hodge_adj @ rank2
        ccnn_v = torch.nn.Identity()
        return ccnn_q, ccnn_k, ccnn_v
class HodgeAdjAttentionLayer(torch.nn.Module):
    """HodgeAdjAttentionLayer for ScoreNetworkA_CC

    One HodgeAttention is applied per input channel of the hodge adjacency
    matrix. An MLP mixes the per-channel attention matrices into the output
    hodge adjacency channels. Another MLP mixes the per-channel values into
    the output rank-2 incidence matrix.
    """

    def __init__(
        self,
        num_linears: int,
        input_dim: int,
        attn_dim: int,
        conv_output_dim: int,
        N: int,
        d_min: int,
        d_max: int,
        num_heads: int = 4,
        conv: str = "HCN",
        use_bn: bool = False,
    ) -> None:
        """Initialize the HodgeAdjAttentionLayer

        Args:
            num_linears (int): number of linear layers in the MLPs
            input_dim (int): input dimension of the HodgeAdjAttentionLayer
                (number of input channels, also number of HodgeAttention)
            attn_dim (int): attention dimension
            conv_output_dim (int): output dimension of the MLP (output number of channels)
            N (int): maximum number of nodes
            d_min (int): minimum size of rank-2 cells
            d_max (int): maximum size of rank-2 cells
            num_heads (int, optional): number of heads for the Attention. Defaults to 4.
            conv (str, optional): type of convolutional layer used by each
                HodgeAttention, choose from [HCN, MLP]. Defaults to "HCN".
            use_bn (bool, optional): whether to use batch normalization in the MLP. Defaults to False.

        Raises:
            NotImplementedError: if conv is not one of [HCN, MLP]
                (raised while building the HodgeAttention layers)
        """
        super().__init__()
        # Define the parameters of the layer
        self.num_linears = num_linears
        self.input_dim = input_dim
        self.attn_dim = attn_dim
        self.conv_output_dim = conv_output_dim
        self.N = N
        self.d_min = d_min
        self.d_max = d_max
        # K: second value returned by get_rank2_dim, used as the last dimension
        # of the rank-2 incidence matrix (B x (NC2) x K)
        self.K = get_rank2_dim(N, d_min, d_max)[1]
        self.num_heads = num_heads
        # NOTE: HodgeAttention only accepts "HCN" or "MLP"; a "GCN" default
        # made construction with the default arguments raise NotImplementedError.
        self.conv = conv
        self.use_bn = use_bn

        # Define the layers: one HodgeAttention per input channel,
        # each built with in_dim = out_dim = K
        self.attn = torch.nn.ModuleList(
            [
                HodgeAttention(
                    self.K,
                    self.attn_dim,
                    self.K,
                    num_heads=self.num_heads,
                    conv=self.conv,
                )
                for _ in range(self.input_dim)
            ]
        )
        self.hidden_dim = 2 * max(self.input_dim, self.conv_output_dim)
        # Mixes the input_dim per-channel values into a single rank-2 matrix
        self.mlp_value = MLP(
            num_layers=self.num_linears,
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            output_dim=1,
            use_bn=self.use_bn,
            activate_func=torch_func.elu,
        )
        # Mixes the input_dim per-channel attention matrices into conv_output_dim channels
        self.mlp_attention = MLP(
            num_layers=self.num_linears,
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            output_dim=self.conv_output_dim,
            use_bn=self.use_bn,
            activate_func=torch_func.elu,
        )
        # Reset the parameters
        self.reset_parameters()

    def __repr__(self) -> str:
        """Representation of the HodgeAdjAttentionLayer

        Returns:
            str: representation of the HodgeAdjAttentionLayer
        """
        return (
            f"{self.__class__.__name__}("
            f"input_dim={self.input_dim}, "
            f"attn_dim={self.attn_dim}, "
            f"conv_output_dim={self.conv_output_dim}, "
            f"num_heads={self.num_heads}, "
            f"conv={self.conv}, "
            f"hidden_dim={self.hidden_dim})"
        )

    def reset_parameters(self) -> None:
        """Reset the parameters of the HodgeAdjAttentionLayer"""
        # Reset the MLPs
        self.mlp_value.reset_parameters()
        self.mlp_attention.reset_parameters()
        # Reset the attention layers
        for attn in self.attn:
            attn.reset_parameters()

    def forward(
        self,
        hodge_adj: torch.Tensor,
        rank2: torch.Tensor,
        flags: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for the HodgeAdjAttentionLayer. Returns a hodge adjacency matrix and a rank-2 incidence matrix.

        Args:
            hodge_adj (torch.Tensor): hodge adjacency matrix (B x C_i x (NC2) x (NC2))
                C_i is the number of input channels
            rank2 (torch.Tensor): rank-2 incidence matrix (B x (NC2) x K)
            flags (Optional[torch.Tensor]): flags for the nodes

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: hodge adjacency matrix and a rank-2 incidence matrix
                (B x C_o x (NC2) x (NC2)), (B x (NC2) x K)
                C_o is the number of output channels
        """
        value_list = []
        attention_list = []
        # Apply the k-th attention layer to the k-th input channel
        for k, attn in enumerate(self.attn):
            _value, _attention = attn(hodge_adj[:, k, :, :], rank2, flags)
            value_list.append(_value.unsqueeze(-1))
            attention_list.append(_attention.unsqueeze(-1))

        # Channels-last -> channels-first: (B x (NC2) x (NC2) x C_o) -> (B x C_o x (NC2) x (NC2))
        hodge_adj_out = mask_hodge_adjs(
            self.mlp_attention(torch.cat(attention_list, dim=-1)).permute(0, 3, 1, 2),
            flags,
        )
        hodge_adj_out = torch.tanh(hodge_adj_out)
        hodge_adj_out = hodge_adj_out + hodge_adj_out.transpose(-1, -2)

        # Mix the per-channel values into one channel, then drop that channel axis
        _rank2 = self.mlp_value(torch.cat(value_list, dim=-1)).squeeze(-1)
        rank2_out = mask_rank2(_rank2, self.N, self.d_min, self.d_max, flags)

        return hodge_adj_out, rank2_out