Skip to content

More efficient local push: `q.pop(0)` instead of `q.pop()`? #8

@baojian

Description

@baojian

Hi Dr. Gasteiger,

I'd like to ask about the implementation of `_calc_ppr_node()`, where each active node is popped from the list with `node = q.pop()`. I tested this `pop()` approach, and it seems less efficient than `pop(0)`, which uses the list as a queue (FIFO) rather than a stack (LIFO). Note that `pop()` removes the *last* element of the list. Here is my test code:

import time
import numpy as np
import networkx as nx
from numba import njit
from numpy.linalg import norm
from numpy import sqrt
from numpy import int64
from numpy import float64

import numba


@njit(cache=True, locals={'_val': numba.float32, 'res': numba.float32, 'res_vnode': numba.float32})
def _calc_ppr_node(inode, indptr, indices, deg, alpha, epsilon):
    """Approximate personalized PageRank from ``inode`` by local push (LIFO order).

    Active nodes are kept on a stack: ``q.pop()`` always processes the most
    recently activated node first.

    Parameters
    ----------
    inode : int
        Source node; it starts with residual ``alpha``.
    indptr, indices : ndarray
        CSR adjacency: the neighbours of ``u`` are
        ``indices[indptr[u]:indptr[u + 1]]``.
    deg : ndarray
        Degree of every node.
    alpha : float
        Restart probability; each push forwards ``(1 - alpha)`` of a node's
        residual, split evenly over its neighbours.
    epsilon : float
        Precision; node ``v`` is activated while
        ``r[v] >= alpha * epsilon * deg[v]``.

    Returns
    -------
    nodes : list
        Nodes that received a PPR estimate (the source and every popped node).
    values : list
        Their estimates, in the same order as ``nodes``.
    total_opers : float
        Work counter: the summed degree of every popped node (the number of
        neighbour visits when ``deg`` matches the CSR row lengths).
    """
    alpha_eps = alpha * epsilon
    f32_0 = numba.float32(0)
    p = {inode: f32_0}  # PPR estimates
    r = {inode: alpha}  # residuals not yet pushed
    q = [inode]  # stack of active nodes
    # Set mirroring ``q`` so "is vnode already active?" is O(1); the original
    # ``vnode not in q`` was a linear list scan for every neighbour visited.
    in_q = {inode}
    total_opers = 0.
    while len(q) > 0:
        unode = q.pop()
        in_q.discard(unode)
        res = r[unode] if unode in r else f32_0
        if unode in p:
            p[unode] += res
        else:
            p[unode] = res
        r[unode] = f32_0
        total_opers += deg[unode]
        for vnode in indices[indptr[unode]:indptr[unode + 1]]:
            # Computed inside the loop on purpose: hoisting it would divide by
            # zero (and raise under njit) for a degree-0 node.
            _val = (1 - alpha) * res / deg[unode]
            if vnode in r:
                r[vnode] += _val
            else:
                r[vnode] = _val
            # vnode is guaranteed to be in r after the update above.
            res_vnode = r[vnode]
            if res_vnode >= alpha_eps * deg[vnode] and vnode not in in_q:
                q.append(vnode)
                in_q.add(vnode)
    return list(p.keys()), list(p.values()), total_opers


@njit(cache=True, locals={'_val': numba.float32, 'res': numba.float32, 'res_vnode': numba.float32})
def _calc_ppr_node_pop_first(inode, indptr, indices, deg, alpha, epsilon):
    """Approximate personalized PageRank from ``inode`` by local push (FIFO order).

    Identical to the LIFO variant except that active nodes are processed as a
    queue: the earliest activated node is handled first.

    Parameters
    ----------
    inode : int
        Source node; it starts with residual ``alpha``.
    indptr, indices : ndarray
        CSR adjacency: the neighbours of ``u`` are
        ``indices[indptr[u]:indptr[u + 1]]``.
    deg : ndarray
        Degree of every node.
    alpha : float
        Restart probability; each push forwards ``(1 - alpha)`` of a node's
        residual, split evenly over its neighbours.
    epsilon : float
        Precision; node ``v`` is activated while
        ``r[v] >= alpha * epsilon * deg[v]``.

    Returns
    -------
    nodes : list
        Nodes that received a PPR estimate (the source and every popped node).
    values : list
        Their estimates, in the same order as ``nodes``.
    total_opers : float
        Work counter: the summed degree of every popped node (the number of
        neighbour visits when ``deg`` matches the CSR row lengths).
    """
    alpha_eps = alpha * epsilon
    f32_0 = numba.float32(0)
    p = {inode: f32_0}  # PPR estimates
    r = {inode: alpha}  # residuals not yet pushed
    # FIFO queue of active nodes: ``q[head:]`` is the pending part. Advancing
    # ``head`` is an O(1) pop-from-front, whereas ``list.pop(0)`` shifts the
    # whole list (nopython mode offers no ``collections.deque``). Processed
    # entries stay in ``q``; they are never read again.
    q = [inode]
    head = 0
    # Set mirroring ``q[head:]`` so "is vnode already active?" is O(1) instead
    # of a linear scan of the list.
    in_q = {inode}
    total_opers = 0.
    while head < len(q):
        unode = q[head]  # only change vs. _calc_ppr_node: FIFO instead of LIFO
        head += 1
        in_q.discard(unode)
        res = r[unode] if unode in r else f32_0
        if unode in p:
            p[unode] += res
        else:
            p[unode] = res
        r[unode] = f32_0
        total_opers += deg[unode]
        for vnode in indices[indptr[unode]:indptr[unode + 1]]:
            # Computed inside the loop on purpose: hoisting it would divide by
            # zero (and raise under njit) for a degree-0 node.
            _val = (1 - alpha) * res / deg[unode]
            if vnode in r:
                r[vnode] += _val
            else:
                r[vnode] = _val
            # vnode is guaranteed to be in r after the update above.
            res_vnode = r[vnode]
            if res_vnode >= alpha_eps * deg[vnode] and vnode not in in_q:
                q.append(vnode)
                in_q.add(vnode)
    return list(p.keys()), list(p.values()), total_opers


@njit(cache=True)
def calc_ppr(indptr, indices, deg, alpha, epsilon, nodes, model='pop-last'):
    """Run local-push PPR from every node in ``nodes``.

    Parameters
    ----------
    indptr, indices, deg, alpha, epsilon
        Forwarded unchanged to the per-node push routine.
    nodes : sequence of int
        Source nodes, processed in order.
    model : str
        ``'pop-last'`` processes active nodes as a stack (``_calc_ppr_node``);
        ``'pop-first'`` processes them as a queue
        (``_calc_ppr_node_pop_first``).

    Returns
    -------
    js, vals, opers : list
        Per source node: the touched node ids, their PPR estimates and the
        work counter, in the order of ``nodes``.

    Raises
    ------
    ValueError
        If ``model`` is not one of the two supported strings. Previously any
        other value silently ran the pop-first variant.
    """
    js = []
    vals = []
    opers = []
    for node in nodes:
        if model == 'pop-last':
            j, val, oper = _calc_ppr_node(node, indptr, indices, deg, alpha, epsilon)
        elif model == 'pop-first':
            j, val, oper = _calc_ppr_node_pop_first(node, indptr, indices, deg, alpha, epsilon)
        else:
            raise ValueError("model must be 'pop-last' or 'pop-first'")
        js.append(j)
        vals.append(val)
        opers.append(oper)
    return js, vals, opers


def toy_example_graph():
    """Build the 15-node toy graph used in the GraphSAGE paper as CSR arrays.

    From GraphSAGE: https://arxiv.org/pdf/1706.02216.pdf

    Returns
    -------
    n : int
        Number of nodes.
    indptr, indices : ndarray
        CSR row pointers and column indices of the undirected adjacency matrix,
        with rows ordered by node id 0..n-1.
    degree : ndarray
        Per-node degree, ``indptr[i + 1] - indptr[i]``.
    """
    num_nodes = 15
    neighbours = {0: [1, 12], 1: [0, 2, 5, 9], 2: [1], 3: [5, 6], 4: [5],
                  5: [1, 3, 4, 6, 9], 6: [3, 5], 7: [8], 8: [7, 9], 9: [1, 5, 8, 11, 12],
                  10: [11], 11: [9, 10], 12: [0, 9, 13, 14], 13: [12, 14], 14: [12, 13]}
    graph = nx.Graph()
    graph.add_edges_from((u, v) for u in range(num_nodes) for v in neighbours[u])
    adjacency = nx.to_scipy_sparse_array(graph, nodelist=range(num_nodes))
    indptr = adjacency.indptr
    indices = adjacency.indices
    # Row lengths of the CSR matrix are exactly the node degrees.
    degree = np.diff(indptr)
    return len(degree), indptr, indices, degree


def test_toy_example_graph():
    """Compare pop-last (LIFO) and pop-first (FIFO) local push on the toy graph.

    For each model, in that order, prints the summed PPR estimates from the
    source node and the work counter returned by ``calc_ppr``.
    """
    _, indptr, indices, degree = toy_example_graph()
    s_node = 0  # source node
    alpha = 0.1  # restart probability: each push forwards (1 - alpha) of a residual
    eps = 1e-4  # precision parameter
    for model in ('pop-last', 'pop-first'):
        js, vals, opers = calc_ppr(indptr, indices, degree, alpha, eps, [s_node], model=model)
        print(np.sum(vals[0]), opers[0])


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import
    # (e.g. when pytest collects this module and runs the test itself).
    test_toy_example_graph()

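The above prints the following output (summed PPR estimates from source node 0, total number of operations), first for `pop-last` and then for `pop-first`: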

0.9981211414560676 9704.0
0.9982607923448086 1062.0

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions