pgl.graph_wrapper module: Graph data holders for Paddle GNN.

This package provides an interface to help build static computational graphs for PaddlePaddle.

class pgl.graph_wrapper.BaseGraphWrapper[source]

Bases: object

This module implements the base class for graph wrappers.

PGL is currently developed on the static computational mode of Paddle (we’ll support the dynamic computational mode later). In this mode, a model has to be built on a virtual data holder. BaseGraphWrapper provides a virtual graph structure on which users can build deep learning models, and real graph data is then fed in to run the models. It also provides a convenient message-passing interface (send and recv) for building graph neural networks.

NOTICE: Don’t use BaseGraphWrapper directly. Use GraphWrapper or StaticGraphWrapper to create a graph wrapper instead.

property edge_feat

Return a dictionary of tensors representing edge features.

Returns

A dictionary whose keys are the feature names and whose values are the feature tensors.

property edges

Return a tuple of edge tensors (src, dst).

Returns

A tuple of tensors (src, dst). Both src and dst are tensors with shape (num_edges,) and dtype int64.
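As a rough illustration of the (src, dst) layout described above, the NumPy sketch below splits an edge list into two index arrays. It does not call PGL; the edge list is borrowed from the examples further down this page, and the variable names are only illustrative.

```python
import numpy as np

# Edge list borrowed from the examples on this page.
edges = [(0, 1), (1, 2), (3, 4)]

# Split the (src, dst) pairs into two int64 arrays of shape (num_edges,).
edge_array = np.array(edges, dtype="int64")
src, dst = edge_array[:, 0], edge_array[:, 1]

print(src)  # [0 1 3]
print(dst)  # [1 2 4]
```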

property graph_lod

Return the graph index for a batch of graphs.

Returns

A variable with shape [None] holding the LoD information of multiple graphs.
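The sketch below shows one plausible reading of this LoD information. It assumes the offset-style convention commonly used for Paddle LoD, where entry i is the index of the first node of graph i and the last entry is the total node count. That layout is an assumption and is not stated on this page; the node counts are made up.

```python
import numpy as np

# Hypothetical batch of three graphs with 3, 2 and 4 nodes.
nodes_per_graph = [3, 2, 4]

# Offset-style LoD (assumed layout): prepend 0 to the running node total.
graph_lod = np.concatenate([[0], np.cumsum(nodes_per_graph)]).astype("int64")
print(graph_lod)  # [0 3 5 9]

# Under this layout, nodes of graph i occupy [graph_lod[i], graph_lod[i + 1]).
start, end = graph_lod[1], graph_lod[2]
print(list(range(start, end)))  # [3, 4]
```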

indegree()[source]

Return the indegree tensor for all nodes.

Returns

A tensor with shape (num_nodes,) and dtype int64.
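To make the returned values concrete, here is a small NumPy sketch that computes indegrees from the dst indices of an edge list. It mirrors the description above and is not PGL's implementation; the graph is the 5-node example used elsewhere on this page.

```python
import numpy as np

num_nodes = 5
dst = np.array([1, 2, 4], dtype="int64")  # destination node of each edge

# Count how many edges point at each node. minlength pads the result so
# that nodes without incoming edges are reported with indegree 0.
indegree = np.bincount(dst, minlength=num_nodes).astype("int64")
print(indegree)  # [0 1 1 0 1]
```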

property node_feat

Return a dictionary of tensors representing node features.

Returns

A dictionary whose keys are the feature names and whose values are the feature tensors.

property num_graph

Return a variable holding the number of graphs.

Returns

A variable with shape (1,) and dtype int64 holding the number of graphs.

property num_nodes

Return a variable holding the number of nodes.

Returns

A variable with shape (1,) and dtype int64 holding the number of nodes.

recv(msg, reduce_function)[source]

Receive messages and aggregate them with reduce_function.

The UDF reduce_function should have the following format.

def reduce_func(msg):
    '''
        Args:
            msg: A LodTensor or a dictionary of LodTensor whose batch size
                 equals the number of unique dst nodes.

        Return:
            It should return a tensor with shape (batch_size, out_dims). The
            batch size should be the same as msg.
    '''
    pass
Parameters
  • msg – A tensor or a dictionary of tensors created by the send function.

  • reduce_function – A UDF reduce function, or the string “sum” to select the built-in reducer. The built-in “sum” uses scatter_add to improve speed.

Returns

A tensor with shape (num_nodes, out_dims). The output for nodes with no message will be zeros.
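The behaviour of the built-in “sum” reduction can be sketched in NumPy as a scatter-add over the dst indices. This is an illustration of the semantics described above, not PGL's code; recv_sum and its arguments are hypothetical names.

```python
import numpy as np

def recv_sum(msg, dst, num_nodes):
    """Sum the per-edge messages that arrive at each destination node."""
    out = np.zeros((num_nodes, msg.shape[1]), dtype=msg.dtype)
    # Unbuffered scatter-add: rows of msg that share a dst index accumulate.
    np.add.at(out, dst, msg)
    return out

dst = np.array([1, 2, 2], dtype="int64")              # two edges end at node 2
msg = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])  # shape (num_edges, dims)

out = recv_sum(msg, dst, num_nodes=4)
# Node 1 receives [1, 1], node 2 receives [5, 5]; nodes 0 and 3 receive
# no message, so their rows stay at zero.
print(out)
```

np.add.at is used instead of `out[dst] += msg` because the latter keeps only one contribution when an index is repeated.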

send(message_func, nfeat_list=None, efeat_list=None)[source]

Send messages from all src nodes to dst nodes.

The UDF message function should have the following format.

def message_func(src_feat, dst_feat, edge_feat):
    '''
        Args:
            src_feat: the node feat dict attached to the src nodes.
            dst_feat: the node feat dict attached to the dst nodes.
            edge_feat: the edge feat dict attached to the
                       corresponding (src, dst) edges.

        Return:
            It should return a tensor or a dictionary of tensors. Each tensor
            should have a shape of (num_edges, dims).
    '''
    pass
Parameters
  • message_func – A UDF message function in the format shown above.

  • nfeat_list – A list whose items are node feature names or (name, tensor) tuples.

  • efeat_list – A list whose items are edge feature names or (name, tensor) tuples.

Returns

A dictionary of tensors representing the messages. Each value in the dictionary has shape (num_edges, dim) and should be collected by the recv function.
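The send step can be pictured as gathering the src and dst node features for every edge and then applying the message function once over all edges. The NumPy sketch below follows that reading; it is not PGL's implementation, and the feature names "h" and "w" as well as the send helper's signature are hypothetical.

```python
import numpy as np

def send(message_func, src, dst, node_feat, edge_feat):
    """Gather per-edge inputs and apply message_func once over all edges."""
    src_feat = {name: feat[src] for name, feat in node_feat.items()}
    dst_feat = {name: feat[dst] for name, feat in node_feat.items()}
    return message_func(src_feat, dst_feat, edge_feat)

def message_func(src_feat, dst_feat, edge_feat):
    # Message = source node feature scaled by the edge weight.
    return {"msg": src_feat["h"] * edge_feat["w"]}

src = np.array([0, 1, 3], dtype="int64")
dst = np.array([1, 2, 4], dtype="int64")
node_feat = {"h": np.arange(10, dtype="float32").reshape(5, 2)}      # (num_nodes, 2)
edge_feat = {"w": np.array([[1.0], [2.0], [0.5]], dtype="float32")}  # (num_edges, 1)

msg = send(message_func, src, dst, node_feat, edge_feat)
print(msg["msg"].shape)  # (3, 2)
```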

class pgl.graph_wrapper.GraphWrapper(name, node_feat=[], edge_feat=[], **kwargs)[source]

Bases: pgl.graph_wrapper.BaseGraphWrapper

Implements a graph wrapper that creates graph data holders in which the attributes and features of the graph are fluid.layers.data. It also provides the to_feed interface to help convert Graph data into a feed_dict.

Parameters
  • name – The graph data prefix

  • node_feat – A list of tuples that describe the node feature tensors. Each tuple must be (name, shape, dtype), and the first dimension of the shape must be left unknown (-1 or None). Alternatively, use Graph.node_feat_info() to get the node_feat settings.

  • edge_feat – A list of tuples that describe the edge feature tensors. Each tuple must be (name, shape, dtype), and the first dimension of the shape must be left unknown (-1 or None). Alternatively, use Graph.edge_feat_info() to get the edge_feat settings.
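For readers writing the feature descriptions by hand, the sketch below derives (name, shape, dtype) tuples from NumPy arrays, leaving the first dimension unknown as required above. feat_info is a hypothetical helper, and its exact output format (a list for the shape, a string for the dtype) is an assumption rather than the documented return value of Graph.node_feat_info().

```python
import numpy as np

def feat_info(feat_dict):
    """Hypothetical helper: describe each feature as (name, shape, dtype)."""
    info = []
    for name, value in feat_dict.items():
        # Leave the first dimension unknown so any number of rows can be fed.
        shape = [-1] + list(value.shape[1:])
        info.append((name, shape, str(value.dtype)))
    return info

node_feat = {"feature": np.random.randn(5, 100)}
node_info = feat_info(node_feat)
print(node_info)  # [('feature', [-1, 100], 'float64')]
```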

Examples

import numpy as np
import paddle.fluid as fluid
from pgl.graph import Graph
from pgl.graph_wrapper import GraphWrapper

place = fluid.CPUPlace()
exe = fluid.Executor(place)

num_nodes = 5
edges = [(0, 1), (1, 2), (3, 4)]
feature = np.random.randn(5, 100)
edge_feature = np.random.randn(3, 100)
graph = Graph(num_nodes=num_nodes,
            edges=edges,
            node_feat={
                "feature": feature
            },
            edge_feat={
                "edge_feature": edge_feature
            })

graph_wrapper = GraphWrapper(name="graph",
            node_feat=graph.node_feat_info(),
            edge_feat=graph.edge_feat_info())

# build your deep graph model
...

# Initialize parameters for deep graph model
exe.run(fluid.default_startup_program())

for i in range(10):
    feed_dict = graph_wrapper.to_feed(graph)
    ret = exe.run(fetch_list=[...], feed=feed_dict)
property holder_list

Return the holder list.

to_feed(graph)[source]

Convert the graph into feed_dict.

This function helps convert graph data into a feed dict that fluid.Executor can use to run the model.

Parameters

graph – the Graph data object

Returns

A dictionary that maps data holder names to their corresponding data.
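The shape of such a dictionary can be pictured with the NumPy sketch below, which maps prefixed holder names to arrays. to_feed_sketch is a hypothetical stand-in and the key names are invented for illustration; the names that GraphWrapper actually uses are not given on this page.

```python
import numpy as np

def to_feed_sketch(name, num_nodes, edges, node_feat):
    """Hypothetical stand-in for to_feed: map holder names to arrays."""
    edge_array = np.array(edges, dtype="int64")
    # Key names below are made up; only the prefix idea comes from the docs.
    feed = {
        name + "/num_nodes": np.array([num_nodes], dtype="int64"),
        name + "/edges_src": edge_array[:, 0],
        name + "/edges_dst": edge_array[:, 1],
    }
    for key, value in node_feat.items():
        feed[name + "/node_feat/" + key] = value
    return feed

feed = to_feed_sketch("graph", 5, [(0, 1), (1, 2), (3, 4)],
                      {"feature": np.zeros((5, 100))})
for holder_name, data in feed.items():
    print(holder_name, data.shape)
```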

class pgl.graph_wrapper.StaticGraphWrapper(name, graph, place)[source]

Bases: pgl.graph_wrapper.BaseGraphWrapper

Implements a graph wrapper for a graph whose data does not change and fits into GPU or CPU memory. This can reduce the time spent swapping large data between the GPU and the CPU.

Parameters
  • name – The graph data prefix

  • graph – The static graph that should be put into memory

  • place – fluid.CPUPlace or fluid.CUDAPlace(n) indicating the device to hold the graph data.

Examples

If we have an immutable graph that fits into GPU or CPU memory, we can use a StaticGraphWrapper to pre-place the graph data on the device.

import numpy as np
import paddle.fluid as fluid
from pgl.graph import Graph
from pgl.graph_wrapper import StaticGraphWrapper

place = fluid.CPUPlace()
exe = fluid.Executor(place)

num_nodes = 5
edges = [(0, 1), (1, 2), (3, 4)]
feature = np.random.randn(5, 100)
edge_feature = np.random.randn(3, 100)
graph = Graph(num_nodes=num_nodes,
            edges=edges,
            node_feat={
                "feature": feature
            },
            edge_feat={
                "edge_feature": edge_feature
            })

graph_wrapper = StaticGraphWrapper(name="graph",
            graph=graph,
            place=place)

# build your deep graph model

# Initialize parameters for deep graph model
exe.run(fluid.default_startup_program())

# Initialize graph data
graph_wrapper.initialize(place)
initialize(place)[source]

Place the graph data onto the device.

Parameters

place – fluid.CPUPlace or fluid.CUDAPlace(n) indicating the device to hold the graph data.