diff --git a/qtensor/optimisation/kahypar_ordering/use_kahypar.py b/qtensor/optimisation/kahypar_ordering/use_kahypar.py index e3f3f08a..62b8b3ac 100644 --- a/qtensor/optimisation/kahypar_ordering/use_kahypar.py +++ b/qtensor/optimisation/kahypar_ordering/use_kahypar.py @@ -5,7 +5,7 @@ #from qtensor.optimisation.kahypar_ordering import generate_TN import kahypar as kahypar from os.path import join, abspath, dirname - +import copy # -- Timing from contextlib import contextmanager import time @@ -43,7 +43,7 @@ def set_context(**kwargs): def ka_hg_init(tn): # tn: a dictionary from circ2tn h = tn.values() - l = list({l for word in h for l in word}) #set of unique edges (eg, v_1) + l = list({l for word in h for l in word}) #set of unique vertexes (eg, '0x7ff77d9dfb50') l.sort() nodes = list(range(0, len(l))) edges = [] @@ -113,11 +113,18 @@ def recur_partition(tn,**kwargs): K = int(kwargs.get('K')) while max([len(x) for x in tn_partite_list[layer]]) > K-1: layer += 1 - tn_partite_list.append([]) + result = [] for (count,subgraph) in enumerate(tn_partite_list[layer-1]): if subgraph != {}: # important - tn_partite_list[layer][2*count:2*count] = subgraph_partition(subgraph,**kwargs) - + result.extend(subgraph_partition(subgraph,**kwargs)) + + if result == tn_partite_list[layer-1]: + return tn_partite_list # for large imbalance + else: + tn_partite_list.append(result) + #tn_partite_list[layer].extend(result) + #tn_partite_list[layer][K*count:K*count] = result + # TODO: Adjust the hyperparameters during the partition return tn_partite_list @@ -154,7 +161,6 @@ def tree2order(tn,tn_partite_list): all_edge = list(tn.keys()) layer_num = len(tn_partite_list) order = [] - import copy order_tree = copy.deepcopy(tn_partite_list) t = 0 # count the temp result in order_tree for layer in range(layer_num): @@ -203,154 +209,103 @@ def tree2order(tn,tn_partite_list): order = [x for x in order if type(x) != str] assert len(order) == len(all_edge) - # complete the top of order_tree + + #complete the top of order_tree set_last1 = set(list(tn_partite_list[0][0].keys())) set_last2 = set(list(tn_partite_list[0][1].keys())) result = list(set(all_edge) - set_last1 - set_last2) order_tree= [result] + order_tree + return order,order_tree - # find the order from bottom to top - # correct order_tree, - # TODO: there is some bugs in order insertion (find the index of children) - all_edge = list(tn.keys()) - tn_partite_list = tn_partite_list[::-1] - import copy - order_tree = copy.deepcopy(tn_partite_list) - t = 0 # count the temp result in order_tree - layer_num = len(tn_partite_list) - order = [] - sub_opt = False # local order search for the bottom graph +def order_tree2ec(order_tree,tn,tn_partite_list): + # There is still some bugs in this function + K = len(tn_partite_list[0]) + ec_tree = copy.deepcopy(tn_partite_list) + t = [[]] *(len(ec_tree[-1])*2) + ec_tree.append(t) #contraction tree like Fig in Johnnie's paper + + layer_num = len(ec_tree) for layer in range(layer_num): - if layer == 0: #bottom layer, append order if there is an edge + # if layer == 0: + # parent_graph = order_tree[layer] + # for (count,subgraph) in enumerate(tn_partite_list[layer]): + # add_list=[]; t = 0 + # self_node = subgraph.values() + # self_node = list({l for word in self_node for l in word}) + # for item in parent_graph: + # parent_node = tn.get(item) + # if parent_node != None: + # if any(i in self_node for i in parent_node): + # add_list.append(item) + # t += 1 + # continue + # ec_tree[layer][count] = add_list + + if layer < layer_num - 1: for (count,subgraph) in enumerate(tn_partite_list[layer]): - if sub_opt is True: - # TODO: when the subgraph is small, call other order optimizor) - #order.append(local_search(subgraph)) - continue + if subgraph != {}: + add_list=[]; t = 0 + self_node = subgraph.values() + self_node = list({l for word in self_node for l in word}) + if layer == 0: + parent_graph = order_tree[layer] + else: + parent_ind = find_parent_ind(subgraph,tn_partite_list, layer) + parent_graph = ec_tree[layer-1][parent_ind] + + for item in parent_graph: + parent_node = tn.get(item) + if parent_node != None: + if any(i in self_node for i in parent_node): + add_list.append(item) + t += 1 + continue + + if layer == 0: + ec_tree[layer][count] = add_list + else: + ec_tree[layer][count] = add_list + order_tree[layer][parent_ind] + #eliminate the "temp" ind + ec_tree[layer][count] = [x for x in ec_tree[layer][count] if type(x)!=str] else: - #if subgraph != {} : - result = list(subgraph.keys()) - if result == [] : - result = [f'temp_{t}'] - t = t+1 - order.extend(result) - order_tree[layer][count]=result - else: - #non-bottom layer, need to insert order - for (count,subgraph) in enumerate(tn_partite_list[layer]): - #left_node_empty = 0 - if 1 == 1 : - node_list = list({l for word in list(subgraph.values()) for l in word}) - ind_last = [] - # find the child node of the subgraph - for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]): - if subgraph_last != {}: - node_list_last = list({l for word in list(subgraph_last.values()) for l in word}) - check = all(item in node_list for item in node_list_last) - if check is True: - ind_last.append(count_last) - - if len(ind_last) == 2: - last_set1 = set(list(tn_partite_list[layer-1][ind_last[0]].keys())) - last_set2 = set(list(tn_partite_list[layer-1][ind_last[1]].keys())) - result = list(set(subgraph) - last_set1 - last_set2) - if result == [] : - result = [f'temp_{t}'] - t = t + 1 - # for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]): - # node_list_last = list({l for word in list(subgraph_last.values()) for l in word}) - # check = all(item in node_list for item in node_list_last) - # if check is True: - # child_ind = count_last - child_set1 = list(order_tree[layer-1][ind_last[0]]) - child_set2 = list(order_tree[layer-1][ind_last[1]]) - - exist_order = [order.index(x) for x in list(child_set1) if x in order] + \ - [order.index(x) for x in list(child_set2) if x in order] - ind = max(exist_order)+1 - order[ind:ind]=result - order_tree[layer][count]=result + ec_tree[layer][count]=[] + elif layer == layer_num - 1: + #TODO: to fix + for (count,_) in enumerate(ec_tree[layer]): + if type(order_tree[layer][count//K]) == list: + ec_tree[layer][count] = order_tree[layer][count//K] + else: + ec_tree[layer][count] = [] - ### there is a single node partition in the subgraph - if len(ind_last) == 1: - last_set = set(list(tn_partite_list[layer-1][ind_last[0]].keys())) - result = list(set(subgraph) - last_set) - if result == [] : - result = [f'temp_{t}'] - t = t + 1 - # for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]): - # node_list_last = list({l for word in list(subgraph_last.values()) for l in word}) - # check = all(item in node_list for item in node_list_last) - # if check is True: - # child_ind = count_last - child_set = list(order_tree[layer-1][ind_last[0]]) - exist_order = [order.index(x) for x in list(child_set) if x in order] - ind = max(exist_order)+1 - order[ind:ind]=result - order_tree[layer][count]=result - - ### there are two single node partition in the subgraph - if len(ind_last) == 0: - result = list(subgraph) - if result == [] : - result = [f'temp_{t}'] - t = t + 1 - for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]): - node_list_last = list({l for word in list(subgraph_last.values()) for l in word}) - check = all(item in node_list for item in node_list_last) - if check is True: - child_ind = count_last - child_set = list(order_tree[layer-1][child_ind]) - - exist_order = [order.index(x) for x in list(child_set) if x in order] - ind = max(exist_order)+1 - order[ind:ind]=result - order_tree[layer][count] = result - - ''' - #count = tn_partite_list[layer].index(subgraph) - if count % 2 == 0: # left node - # find the contracted edge of its paired node in the same layer - if len(tn_partite_list[layer][count+1].keys()) < 2: - order.extend(result) - order_tree[layer][count]=result - # both left and right subgraphs are empty, new start - else: - left_node_empty = 1 - empty_count = count - left_node_buffer = result - order_tree[layer][count]=result - # follow the order of the right-paired subgraph - else: # right node - if len(tn_partite_list[layer][count-1].keys()) < 2: - order.extend(result) - order_tree[layer][count]=result - # both left and right subgraphs are empty, new start - else: - exist_order = [order.index(x) for x in list(tn_partite_list[layer][count-1]) if x in order] - ind = max(exist_order)+1 - order[ind:ind] = result - order_tree[layer][count]=result - # follow the order of the left-paired subgraph - if left_node_empty != 0: - exist_order = [order.index(x) for x in list(tn_partite_list[layer][empty_count+1]) if x in order] - ind = max(exist_order)+1 - order[ind:ind] = left_node_buffer - ''' - - if layer == layer_num - 1: - set_last1 = set(list(tn_partite_list[layer][0].keys())) - set_last2 = set(list(tn_partite_list[layer][1].keys())) - result = list(set(all_edge) - set_last1 - set_last2) - if result == [] : - result = [f'temp_{t}'] - t = t + 1 - order.extend(result) - order_tree.append(result) + # Count the edge contraction from the ec_tree + # Open edges from two subgraphs - 1 + ec=[] #edge contraction + for layer in range(layer_num): + if layer == 0: + temp = [] + for i in range(K): + temp = temp + ec_tree[layer][i] + result = len(set(temp))-1 + ec.append(result) + elif layer < layer_num - 1: + for (count,subgraph) in enumerate(tn_partite_list[layer-1]): + if subgraph != {}: + child_ind = find_child_ind(subgraph,tn_partite_list, layer-1) + temp = [] + #TODO: to fix + if len(child_ind) > 0: + for i in range(len(child_ind)): + temp = temp + ec_tree[layer][child_ind[i]] + result = len(set(temp))-1 + else: + result = 0 + ec.append(result) + elif layer == layer_num - 1: + for (count,_) in enumerate(ec_tree[layer]): + result = len(set(ec_tree[layer][count//2]+ec_tree[layer][count//K + 1])) + if result > 0: + result = result -1 + ec.append(result) - order = [x for x in order if type(x) != str] - assert len(order) == len(all_edge) - order_tree = order_tree[::-1] - - return order, order_tre \ No newline at end of file + return max(ec) \ No newline at end of file