A Graph Neural Network (GNN) from Scratch for semi-supervised learning

Goals :

  1. An illustration of GNN for semi-supervised learning
  2. A toy applicaion to community detection
  3. A lecture on Structured Machine Learning

Author: Romain Raveaux (romain.raveaux@univ-tours.fr)

http://romain.raveaux.free.fr/ then teaching section

Install requirements

The lecture

The content of the is notebook is based on the following lectures : Supervised Machine Learning for structured input/output: Polytech, Tours

  • 1. Introduction to supervised Machine Learning: A probabilistic introduction PDF

  • 2. Connecting local models : The case of chains PDF slides

  • 3. Connecting local models : Beyond chains and trees.PDF slides

  • 4. Machine Learning and Graphs : Introduction and problems PDF slides

  • 5. Graph Neural Networks. PDF slides

  • 6. Graph Kernels. PDF slides

  • 7. Appendix : Introduction to deep learning. PDF slides

In [1]:
#load of import
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.animation as animation
%matplotlib inline  
from IPython.display import HTML

Zachary’s Karate Club

Briefly, Zachary’s Karate Club is a small social network where a conflict arises between the administrator and instructor in a karate club. The task is to predict which side of the conflict each member of the karate club chooses. The graph representation of the network can be seen below. Each node represents a member of the karate club and a link between members indicate that they interact outside the club. The Administrator and Instructor marked with A and I, respectively.

zakaryclub.png

In [2]:
from networkx import karate_club_graph, to_numpy_matrix
G = karate_club_graph()
#order = sorted(list(G.nodes()))
nx.draw(G, with_labels=True, font_weight='bold')
plt.show()
print('We have %d nodes.' % G.number_of_nodes())
print('We have %d edges.' % G.number_of_edges())
We have 34 nodes.
We have 78 edges.

Assign features to nodes or edges

Graph neural networks associate features with nodes and edges for training. For our classification example, we assign each node an input feature as a one-hot vector: node $v_i$‘s feature vector is $[0,…,1,…,0]$, where the $i^{th}$ position is one.

In DGL, you can add features for all nodes at once, using a feature tensor that batches node features along the first dimension. The code below adds the one-hot feature for all nodes:

In [3]:
import torch

eye=np.eye(34)

node_label = {}
for i in range(G.number_of_nodes()):
    node_label[i]=eye[i,:]
nx.set_node_attributes(G,node_label,'feature')


print(len(G.nodes))
print(len(G.edges))
print(G.number_of_nodes())
print(G.number_of_edges())


listnodes = list(G.nodes(data='feature'))
n=listnodes[0]
print(n)
print(n[1].shape)
34
78
34
78
(0, array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
(34,)

Let 's us get the node features and the adjacency matrix

In [4]:
nodelist, nodesfeatures = map(list, zip(*G.nodes(data='feature')))
print(len(nodesfeatures))
nodesfeatures = np.array(nodesfeatures)
print(nodesfeatures.shape)
nodesfeatures = torch.from_numpy(nodesfeatures).float()
adjacencymatrix = np.array(nx.adjacency_matrix(G, nodelist=nodelist).todense())
print(adjacencymatrix.shape)
adjacencymatrix = torch.from_numpy(adjacencymatrix).float()
34
(34, 34)
(34, 34)

Data preparation and initialization

We use one-hot vectors to initialize the node features. Since this is a semi-supervised setting, only the instructor (node 0) and the club president (node 33) are assigned labels. The implementation is available as follow.

In [5]:
#inputs = torch.eye(34)
labeled_nodes = torch.tensor([0, 33])  # only the instructor and the president nodes are labeled
print(labeled_nodes)
labels = torch.tensor([0, 1])  # their labels are different
print(labels)
tensor([ 0, 33])
tensor([0, 1])

Define a Model : Graph Convolution layer

Firstly, we have to define a Graph Convolution layer Graph Convolution Graph Convolution

$A_j=A^j$. It encodes j-hop neighbourhood of each node and allows to aggregate local information at different scales, which is usefull for regular graphs.

By denoting $\mathcal{A}=\{I,A,A_2,A_3\}$, a GNN layer is defined as : Graph Convolution

Note that the U operator is not used.

In [6]:
import torch
import torch.nn as nn

class GraphConvolution(nn.Module):
  """
     Graph convolution layer
  """
  
  def __init__(self, in_features, out_features, bias=True, batchnorm=False):
    super(GraphConvolution, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.bias = bias
    self.fc = nn.Linear(4*self.in_features, self.out_features, bias=self.bias)
    
    self.batchnorm = batchnorm
    
      
  #H are node features for all graphs batch
  #W are adjacency matrix for all graphs batch
  #A is the set of operators
  # GraphConv = A[0].H_0.W_0 
  def forward(self, H, A):
    res = torch.zeros((H.shape[0],self.in_features*4))
     
    
    output1 = torch.matmul(A[0], H)
    res[:,0:self.in_features]=output1
    
    output2 = torch.matmul(A[1], H)
    res[:,self.in_features:2*self.in_features]=output2
    
    output3 = torch.matmul(A[2], H)
    res[:,2*self.in_features:3*self.in_features]=output3
    
    output4 = torch.matmul(A[3], H)
    res[:,3*self.in_features:4*self.in_features]=output4

    #FC is just a linear function input multiplied by the paramaters W
    output = self.fc(res)
    
    return output

Let us define the network : 2 graph conv layers

In [7]:
# A Simple model with 2 graph conv layers
# activation function are ReLus
class Net(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_classes):
    super(Net, self).__init__()
    self.layers = nn.ModuleList([
        GraphConvolution(in_dim, hidden_dim),
        GraphConvolution(hidden_dim, n_classes)])
    
  def forward(self, h, adj):
    # Add self connections to the adjacency matrix
    id = torch.eye(h.shape[0])
    adj2=torch.pow(adj,2)
    adj3=torch.pow(adj,3)
    #one = torch.ones(h.shape[0])
    for conv in self.layers:
      h = F.relu(conv(h, [id,adj,adj2,adj3]))
    
    
    return h

Let us train : Nothing to change

The training loop is exactly the same as other PyTorch models. We (1) create an optimizer, (2) feed the inputs to the model, (3) calculate the loss and (4) use autograd to optimize the model

Semi-supervised learning

image.png Where $Z$ is the output of the graph neural network. $l_{reg}$ can be computed as follows by the Frobenius norm : $$l_{reg}=||Z.Z^T -A ||_{F}^2 $$.

Let us start with the supervised learning version

In [8]:
nb_channels=34 
num_class=2
num_hidden=5
model = Net(nb_channels, num_hidden,num_class )
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
lossfunction =  torch.nn.CrossEntropyLoss()
all_logits = []
optimizer.zero_grad()
for epoch in range(30):
    prediction = model(nodesfeatures, adjacencymatrix)
    # we save the prediction for visualization later
    all_logits.append(prediction.detach())
    
    # we only compute loss for labeled nodes
    loss0 = lossfunction(prediction[labeled_nodes], labels.long())
    loss=loss0
    #The crossentropy loss does the same as 
    #logp = F.log_softmax(prediction, 1)
    #loss0 = F.nll_loss(logp[labeled_nodes], labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))
Epoch 0 | Loss: 0.5651
Epoch 1 | Loss: 0.3199
Epoch 2 | Loss: 0.1667
Epoch 3 | Loss: 0.0628
Epoch 4 | Loss: 0.0201
Epoch 5 | Loss: 0.0060
Epoch 6 | Loss: 0.0021
Epoch 7 | Loss: 0.0008
Epoch 8 | Loss: 0.0003
Epoch 9 | Loss: 0.0001
Epoch 10 | Loss: 0.0001
Epoch 11 | Loss: 0.0000
Epoch 12 | Loss: 0.0000
Epoch 13 | Loss: 0.0000
Epoch 14 | Loss: 0.0000
Epoch 15 | Loss: 0.0000
Epoch 16 | Loss: 0.0000
Epoch 17 | Loss: 0.0000
Epoch 18 | Loss: 0.0000
Epoch 19 | Loss: 0.0000
Epoch 20 | Loss: 0.0000
Epoch 21 | Loss: 0.0000
Epoch 22 | Loss: 0.0000
Epoch 23 | Loss: 0.0000
Epoch 24 | Loss: 0.0000
Epoch 25 | Loss: 0.0000
Epoch 26 | Loss: 0.0000
Epoch 27 | Loss: 0.0000
Epoch 28 | Loss: 0.0000
Epoch 29 | Loss: 0.0000

Visualization

This is a rather toy example, so it does not even have a validation or test set. Instead, Since the model produces an output feature of size 2 for each node, we can visualize by plotting the output feature in a 2D space. The following code animates the training process from initial guess (where the nodes are not classified correctly at all) to the end (where the nodes are linearly separable).

In [9]:
def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = pos[v].argmax()
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw_networkx(G, pos, node_color=colors,
            with_labels=True, node_size=300, ax=ax)

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
draw(0)  # draw the prediction of the first epoch
plt.close()
In [10]:
ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)
HTML(ani.to_html5_video())
Out[10]:

Unsupervised Version only lreg

In [166]:
def regularization(ypred,adj):
    transpo=ypred.t()
    mult=ypred.matmul(transpo)
    dif = mult-adj
    res = torch.norm(dif, p='fro')
    return res**2

nb_channels=34 
num_class=2
num_hidden=2
model = Net(nb_channels, num_hidden,num_class )
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lossfunction =  torch.nn.CrossEntropyLoss()
all_logits = []
for epoch in range(100):
    prediction = model(nodesfeatures, adjacencymatrix)
    # we save the prediction for visualization later
    all_logits.append(prediction.detach())
    
    lossreg = regularization(prediction,adjacencymatrix)
    loss=lossreg
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))
Epoch 0 | Loss: 131.5302
Epoch 1 | Loss: 128.1530
Epoch 2 | Loss: 126.1007
Epoch 3 | Loss: 124.9050
Epoch 4 | Loss: 124.2817
Epoch 5 | Loss: 124.0941
Epoch 6 | Loss: 123.8488
Epoch 7 | Loss: 123.3331
Epoch 8 | Loss: 122.6722
Epoch 9 | Loss: 121.9526
Epoch 10 | Loss: 121.2013
Epoch 11 | Loss: 120.4059
Epoch 12 | Loss: 119.5302
Epoch 13 | Loss: 118.5621
Epoch 14 | Loss: 117.6331
Epoch 15 | Loss: 116.8341
Epoch 16 | Loss: 116.1008
Epoch 17 | Loss: 115.4888
Epoch 18 | Loss: 115.0089
Epoch 19 | Loss: 114.6506
Epoch 20 | Loss: 114.3429
Epoch 21 | Loss: 114.0592
Epoch 22 | Loss: 113.8173
Epoch 23 | Loss: 113.5870
Epoch 24 | Loss: 113.3656
Epoch 25 | Loss: 113.1522
Epoch 26 | Loss: 112.9321
Epoch 27 | Loss: 112.7247
Epoch 28 | Loss: 112.5516
Epoch 29 | Loss: 112.4059
Epoch 30 | Loss: 112.2662
Epoch 31 | Loss: 112.1294
Epoch 32 | Loss: 112.0095
Epoch 33 | Loss: 111.9115
Epoch 34 | Loss: 111.8205
Epoch 35 | Loss: 111.7281
Epoch 36 | Loss: 111.6472
Epoch 37 | Loss: 111.5863
Epoch 38 | Loss: 111.5371
Epoch 39 | Loss: 111.4904
Epoch 40 | Loss: 111.4504
Epoch 41 | Loss: 111.4207
Epoch 42 | Loss: 111.3915
Epoch 43 | Loss: 111.3553
Epoch 44 | Loss: 111.3166
Epoch 45 | Loss: 111.2787
Epoch 46 | Loss: 111.2371
Epoch 47 | Loss: 111.1899
Epoch 48 | Loss: 111.1432
Epoch 49 | Loss: 111.1003
Epoch 50 | Loss: 111.0581
Epoch 51 | Loss: 111.0163
Epoch 52 | Loss: 110.9787
Epoch 53 | Loss: 110.9457
Epoch 54 | Loss: 110.9136
Epoch 55 | Loss: 110.8823
Epoch 56 | Loss: 110.8529
Epoch 57 | Loss: 110.8246
Epoch 58 | Loss: 110.7944
Epoch 59 | Loss: 110.7633
Epoch 60 | Loss: 110.7329
Epoch 61 | Loss: 110.7022
Epoch 62 | Loss: 110.6709
Epoch 63 | Loss: 110.6410
Epoch 64 | Loss: 110.6129
Epoch 65 | Loss: 110.5858
Epoch 66 | Loss: 110.5597
Epoch 67 | Loss: 110.5351
Epoch 68 | Loss: 110.5112
Epoch 69 | Loss: 110.4870
Epoch 70 | Loss: 110.4624
Epoch 71 | Loss: 110.4377
Epoch 72 | Loss: 110.4120
Epoch 73 | Loss: 110.3853
Epoch 74 | Loss: 110.3580
Epoch 75 | Loss: 110.3245
Epoch 76 | Loss: 110.2910
Epoch 77 | Loss: 110.2586
Epoch 78 | Loss: 110.2273
Epoch 79 | Loss: 110.1968
Epoch 80 | Loss: 110.1666
Epoch 81 | Loss: 110.1350
Epoch 82 | Loss: 110.1019
Epoch 83 | Loss: 110.0673
Epoch 84 | Loss: 110.0319
Epoch 85 | Loss: 109.9959
Epoch 86 | Loss: 109.9600
Epoch 87 | Loss: 109.9238
Epoch 88 | Loss: 109.8868
Epoch 89 | Loss: 109.8502
Epoch 90 | Loss: 109.8128
Epoch 91 | Loss: 109.7732
Epoch 92 | Loss: 109.7319
Epoch 93 | Loss: 109.6887
Epoch 94 | Loss: 109.6435
Epoch 95 | Loss: 109.5962
Epoch 96 | Loss: 109.5471
Epoch 97 | Loss: 109.4952
Epoch 98 | Loss: 109.4408
Epoch 99 | Loss: 109.3795
In [167]:
def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = pos[v].argmax()
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw_networkx(G, pos, node_color=colors,
            with_labels=True, node_size=300, ax=ax)

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
draw(0)  # draw the prediction of the first epoch
plt.close()
In [168]:
ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)
HTML(ani.to_html5_video())
Out[168]:

Semi-Supervised version

In [188]:
def regularization(ypred,adj):
    transpo=ypred.t()
    mult=ypred.matmul(transpo)
    dif = mult-adj
    res = torch.norm(dif, p='fro')
    #return res
    return res**2

nb_channels=34 
num_class=2
num_hidden=2
lambdaa=0.0001
model = Net(nb_channels, num_hidden,num_class )
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
lossfunction =  torch.nn.CrossEntropyLoss()
all_logits = []
for epoch in range(100):
    prediction = model(nodesfeatures, adjacencymatrix)
    # we save the prediction for visualization later
    all_logits.append(prediction.detach())
    
    loss0 = lossfunction(prediction[labeled_nodes], labels.long())
    
    
    #logp = F.log_softmax(prediction, 1)
    lossreg = regularization(prediction,adjacencymatrix)
    
    loss=loss0+lambdaa*lossreg
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))
Epoch 0 | Loss: 0.6578
Epoch 1 | Loss: 0.5114
Epoch 2 | Loss: 0.3937
Epoch 3 | Loss: 0.3292
Epoch 4 | Loss: 0.2821
Epoch 5 | Loss: 0.2479
Epoch 6 | Loss: 0.2165
Epoch 7 | Loss: 0.1947
Epoch 8 | Loss: 0.1829
Epoch 9 | Loss: 0.1757
Epoch 10 | Loss: 0.1668
Epoch 11 | Loss: 0.1556
Epoch 12 | Loss: 0.1443
Epoch 13 | Loss: 0.1362
Epoch 14 | Loss: 0.1322
Epoch 15 | Loss: 0.1291
Epoch 16 | Loss: 0.1241
Epoch 17 | Loss: 0.1177
Epoch 18 | Loss: 0.1125
Epoch 19 | Loss: 0.1089
Epoch 20 | Loss: 0.1073
Epoch 21 | Loss: 0.1051
Epoch 22 | Loss: 0.1023
Epoch 23 | Loss: 0.0994
Epoch 24 | Loss: 0.0970
Epoch 25 | Loss: 0.0959
Epoch 26 | Loss: 0.0957
Epoch 27 | Loss: 0.0951
Epoch 28 | Loss: 0.0937
Epoch 29 | Loss: 0.0919
Epoch 30 | Loss: 0.0905
Epoch 31 | Loss: 0.0897
Epoch 32 | Loss: 0.0895
Epoch 33 | Loss: 0.0893
Epoch 34 | Loss: 0.0888
Epoch 35 | Loss: 0.0880
Epoch 36 | Loss: 0.0872
Epoch 37 | Loss: 0.0866
Epoch 38 | Loss: 0.0863
Epoch 39 | Loss: 0.0863
Epoch 40 | Loss: 0.0862
Epoch 41 | Loss: 0.0860
Epoch 42 | Loss: 0.0856
Epoch 43 | Loss: 0.0852
Epoch 44 | Loss: 0.0850
Epoch 45 | Loss: 0.0849
Epoch 46 | Loss: 0.0849
Epoch 47 | Loss: 0.0847
Epoch 48 | Loss: 0.0845
Epoch 49 | Loss: 0.0843
Epoch 50 | Loss: 0.0841
Epoch 51 | Loss: 0.0839
Epoch 52 | Loss: 0.0838
Epoch 53 | Loss: 0.0837
Epoch 54 | Loss: 0.0836
Epoch 55 | Loss: 0.0834
Epoch 56 | Loss: 0.0832
Epoch 57 | Loss: 0.0830
Epoch 58 | Loss: 0.0829
Epoch 59 | Loss: 0.0828
Epoch 60 | Loss: 0.0827
Epoch 61 | Loss: 0.0825
Epoch 62 | Loss: 0.0824
Epoch 63 | Loss: 0.0823
Epoch 64 | Loss: 0.0821
Epoch 65 | Loss: 0.0820
Epoch 66 | Loss: 0.0819
Epoch 67 | Loss: 0.0818
Epoch 68 | Loss: 0.0817
Epoch 69 | Loss: 0.0815
Epoch 70 | Loss: 0.0814
Epoch 71 | Loss: 0.0813
Epoch 72 | Loss: 0.0811
Epoch 73 | Loss: 0.0810
Epoch 74 | Loss: 0.0808
Epoch 75 | Loss: 0.0807
Epoch 76 | Loss: 0.0805
Epoch 77 | Loss: 0.0803
Epoch 78 | Loss: 0.0802
Epoch 79 | Loss: 0.0800
Epoch 80 | Loss: 0.0799
Epoch 81 | Loss: 0.0797
Epoch 82 | Loss: 0.0796
Epoch 83 | Loss: 0.0794
Epoch 84 | Loss: 0.0793
Epoch 85 | Loss: 0.0792
Epoch 86 | Loss: 0.0790
Epoch 87 | Loss: 0.0789
Epoch 88 | Loss: 0.0788
Epoch 89 | Loss: 0.0786
Epoch 90 | Loss: 0.0785
Epoch 91 | Loss: 0.0784
Epoch 92 | Loss: 0.0783
Epoch 93 | Loss: 0.0781
Epoch 94 | Loss: 0.0780
Epoch 95 | Loss: 0.0779
Epoch 96 | Loss: 0.0778
Epoch 97 | Loss: 0.0777
Epoch 98 | Loss: 0.0776
Epoch 99 | Loss: 0.0775
In [189]:
def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = pos[v].argmax()
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw_networkx(G, pos, node_color=colors,
            with_labels=True, node_size=300, ax=ax)

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
draw(0)  # draw the prediction of the first epoch
plt.close()
In [190]:
ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)
HTML(ani.to_html5_video())
Out[190]:
In [ ]: