The over-smoothing issue in graph neural networks


Graphs are everywhere, and graph data takes the form of interconnected nodes carrying feature vectors. We could use a multilayer perceptron on the node features alone to solve our downstream tasks, but we would lose the connectivity that the graph topology offers. Convolutional neural networks, for their part, are dedicated to a special case of graph: grid-structured inputs, where every node has the same small, regular neighborhood. What we need, then, is a model that builds on both sources of information, the nodes' features and the local structure of the graph, to ease our downstream task; and that is precisely what a GNN does.

What tasks do GNNs train on?

Now that we have modestly justified the existence of such models, we shall uncover their usage. There are many tasks GNNs can be trained on: node classification in large graphs (e.g., segmenting users in a social network based on their attributes and relations) or whole-graph classification (e.g., classifying protein structures for pharmaceutical applications). Beyond classification, regression problems can also be formulated on graph data, on nodes as well as on edges.

To sum up, graph neural network applications are endless and depend on the user's objective and the type of data they possess. For the sake of simplicity, we will focus on node classification within a single graph, where we try to map a subset of graph nodes, described by their feature vectors, to a set of predefined categories/classes.

The problem supposes the existence of a training set: a labeled subset of nodes, with every node in the graph carrying a feature vector that we denote x. Our goal is to predict the labels of the nodes in a validation set.
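Concretely, such a dataset can be held in a few plain structures. Below is a minimal sketch with made-up node ids, features, and labels (all names and values here are hypothetical, not from any real dataset):

```python
# Toy node-classification setup (hypothetical data).
# Adjacency list: node id -> list of neighbor ids.
graph = {
    0: [1, 2],
    1: [0, 2],
    2: [0, 1, 3],
    3: [2, 4],
    4: [3],
}

# Each node has a feature vector x (here, 2-dimensional).
features = {
    0: [1.0, 0.2],
    1: [0.9, 0.1],
    2: [0.8, 0.3],
    3: [0.1, 0.9],
    4: [0.2, 1.0],
}

# Only a subset of nodes is labeled (the training set);
# None marks the nodes whose labels we want to predict.
labels = {0: "A", 1: "A", 2: None, 3: "B", 4: None}

train_nodes = [n for n, y in labels.items() if y is not None]
unlabeled_nodes = [n for n, y in labels.items() if y is None]
print(train_nodes)      # nodes we can train on -> [0, 1, 3]
print(unlabeled_nodes)  # nodes to classify    -> [2, 4]
```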

Node classification example: all nodes have a feature vector; colored nodes are labeled, while those in white are unlabeled [illustration by author]

GNN under the hood

Now that we have set up our problem, it's time to understand how the GNN model is trained to output classes for unlabeled nodes. We want our model to use not only the feature vectors of our nodes but also the graph structure at our disposal.

This last statement, which makes GNNs unique, rests on a hypothesis: neighboring nodes tend to share the same label. GNNs incorporate it through the message passing formalism, a concept that will be discussed further in this article, along with some of the bottlenecks it introduces.

Enough abstraction; let us now look at how GNNs are constructed. A GNN model contains a series of layers that communicate via updated node representations: each layer outputs an embedding vector for each node, which the next layer takes as input and builds upon.

The purpose of our model is to construct these embeddings (one per node), integrating both the node's initial feature vector and information about the local graph structure that surrounds it. Once we have well-representing embeddings, we feed them to a classical softmax layer to output the class probabilities.
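The softmax step mentioned above simply turns raw class scores into a probability distribution. A minimal sketch (the scores here are made up; in a real model they come from a learned linear layer applied to the embedding):

```python
import math

def softmax(scores):
    """Turn raw class scores into probabilities that sum to 1."""
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical final-layer scores for one node over 3 classes.
scores = [2.0, 1.0, 0.1]
probs = softmax(scores)
print(probs)  # the largest score gets the largest probability
```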

The goal of GNN is to transform node features to features that are aware of the graph structure [illustration by author]

To build those embeddings, GNN layers use a straightforward mechanism called message passing, which lets graph nodes exchange information with their neighbors and thus update their embedding vectors layer after layer.

  • The message passing framework

It all starts with nodes whose attributes are described by a feature vector x. Each node then collects the feature vectors of its neighbors through a permutation-invariant function (mean, max, min, ...), i.e., a function that is not sensitive to the ordering of the nodes. This operation is called Aggregate, and it outputs a Message vector.
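Permutation invariance is easy to see with a mean aggregator: shuffling the neighbors leaves the message unchanged. A small sketch:

```python
def aggregate_mean(neighbor_feats):
    """Mean aggregation: a permutation-invariant Aggregate step."""
    dim = len(neighbor_feats[0])
    n = len(neighbor_feats)
    return [sum(f[d] for f in neighbor_feats) / n for d in range(dim)]

# Two orderings of the same neighbor set produce the same message.
feats = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]
msg1 = aggregate_mean(feats)
msg2 = aggregate_mean(list(reversed(feats)))
print(msg1, msg2)  # identical: [2.0, 2.0] [2.0, 2.0]
```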

The second step is the Update function, where the node combines the information gathered from its neighbors (the Message vector) with its own information (its feature vector) to construct a new vector h: the embedding.
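Putting the two steps together gives one message passing layer. The sketch below uses a deliberately simple Update (elementwise average of the node's vector and its message, followed by a ReLU); real layers use learned weight matrices, so treat this only as an illustration of the flow:

```python
def aggregate_mean(neighbor_feats):
    """Aggregate step: permutation-invariant mean over neighbor vectors."""
    dim = len(neighbor_feats[0])
    return [sum(f[d] for f in neighbor_feats) / len(neighbor_feats)
            for d in range(dim)]

def update(x, message):
    """Update step: combine the node's own vector with its message.
    Here: elementwise average + ReLU (a toy choice, no learned weights)."""
    return [max(0.0, (xi + mi) / 2.0) for xi, mi in zip(x, message)]

def message_passing_layer(graph, feats):
    """One round of message passing over every node."""
    new_feats = {}
    for node, neighbors in graph.items():
        msg = aggregate_mean([feats[n] for n in neighbors])
        new_feats[node] = update(feats[node], msg)
    return new_feats

# Toy graph: node -> neighbors (hypothetical data).
graph = {0: [1, 2], 1: [0], 2: [0]}
feats = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.0, 1.0]}
h = message_passing_layer(graph, feats)
print(h)  # all three embeddings become [0.5, 0.5]
```

Notice that after a single round, all three nodes already share the same embedding: stacking many such layers pushes embeddings toward each other, which is exactly the over-smoothing issue this article's title refers to.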

The instantiations of these Aggregate and Update functions differ from one paper to another; you can refer to GCN [1], GraphSAGE [2], GAT [3], or others, but the message passing idea stays the same.


Trending AI/ML Article Identified & Digested via Granola by Ramsey Elbasheer; a Machine-Driven RSS Bot
