diff --git a/.gitignore b/.gitignore index 3a2a9d2..6de8f23 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,12 @@ *.jl.*.cov *.jl.cov *.jl.mem +*.~ /docs/Manifest*.toml /docs/build/ Manifest.toml .vscode/ -.DS_Store \ No newline at end of file +.DS_Store +.* +*.ipynb diff --git a/Project.toml b/Project.toml index fff4950..0da6e77 100644 --- a/Project.toml +++ b/Project.toml @@ -10,8 +10,10 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" +GraphIO = "aa1b3936-2fda-51b9-ab35-c553d3a640a2" GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97" ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" @@ -32,8 +34,10 @@ Adapt = "4.3.0" Combinatorics = "1.0.3" Dictionaries = "0.4" EinExprs = "0.6.4" +GraphIO = "0.7.1" GraphRecipes = "0.5.13" Graphs = "1.8.0" +HyperDualNumbers = "4.0.10" ITensorMPS = "0.3.17" ITensors = "0.9" KrylovKit = "0.10.2" diff --git a/cluster/expect-corrected.jl b/cluster/expect-corrected.jl new file mode 100644 index 0000000..1d9a0ba --- /dev/null +++ b/cluster/expect-corrected.jl @@ -0,0 +1,128 @@ +using TensorNetworkQuantumSimulator +using NamedGraphs +using NamedGraphs: AbstractNamedGraph +using Graphs + +const G = Graphs +const NG = NamedGraphs +const TN = TensorNetworkQuantumSimulator +using HyperDualNumbers +using Adapt: adapt +using Dictionaries + +function prep_insertions(obs) + if isnothing(obs) + return (coeffs = identity, op_strings = v->"I") + end + op_strings, verts, _ = TN.collectobservable(obs) + @assert length(verts) <= 2 + + function hyper_coeff(v) + if v==verts[1] + return Hyper(0,1,0,0) + elseif length(verts)==2 && v==verts[2] + return Hyper(0,0,1,0) + else + return 1 
+ end + end + + function insertion_operator(v) + if v==verts[1] + return op_strings[1] + elseif length(verts)==2 && v==verts[2] + return op_strings[2] + else + return "I" + end + end + return (coeffs = hyper_coeff, op_strings = insertion_operator) +end + +""" +Cluster expansion. See clustercorrections.jl +""" +function cluster_weights(bpc::BeliefPropagationCache, clusters::Vector, egs::Vector, interaction_graph; obs = nothing) + + kwargs = prep_insertions(obs) + + logZbp = TN.free_energy(bpc; kwargs...) + isempty(egs) && return [0], [[logZbp]], [[1]] + + circuit_lengths = sort(unique([c.weight for c=clusters])) + + # Rescale the messages, but deal with the vertices separately + TN.rescale_messages!(bpc) + vns = Dictionary(TN.vertex_scalar(bpc, v; use_epsilon = true, kwargs...) for v=vertices(network(bpc).tensornetwork.graph)) + + # calculate weight of each generalized loop first + wts = TN.weights(bpc, egs; rescales = vns, kwargs...) + + logZs = Array{Array}(undef, length(circuit_lengths) + 1) + logZs[1] = [logZbp] + + coeffs = Array{Array}(undef, length(circuit_lengths) + 1) + coeffs[1] = [1] + + # now calculate contribution to logZ from each cluster + for (cl_i, cl)=enumerate(circuit_lengths) + clusters_cl = filter(c->c.weight==cl, clusters) + logZs[cl_i + 1] = [prod([prod(fill(wts[l],c.multiplicities[l])) for l=c.loop_ids]) for c=clusters_cl] + coeffs[cl_i + 1] = [TN.ursell_function(c, interaction_graph) for c=clusters_cl] + end + + return vcat([0],circuit_lengths), logZs, coeffs +end + +""" +Cluster cumulant expansion. 
See cumulant-clustercorrections.jl +""" +function cc_weights(bpc::BeliefPropagationCache, regions::Vector, counting_nums::Dict; obs = nothing, rescale::Bool = false) + + kwargs = prep_insertions(obs) + + use_g = findall(gg->counting_nums[gg] != 0, regions) + egs = [induced_subgraph(network(bpc).tensornetwork.graph, gg)[1] for gg=regions[use_g]] + + isempty(egs) && return logZbp, [], [] + + # Rescale the messages, but deal with the vertices separately + if rescale + TN.rescale_messages!(bpc) + vns = Dictionary(TN.vertex_scalar(bpc, v; use_epsilon = true, kwargs...) for v=vertices(network(bpc).tensornetwork.graph)) + else + vns = Dictionary(1 for v=vertices(network(bpc).tensornetwork.graph)) + end + + # calculate weight of each cluster first + wts = TN.weights(bpc, egs; rescales = vns, project_out = false, kwargs...) + + return log.(wts), [counting_nums[gg] for gg=regions[use_g]] +end + +""" +onepoint or twopoint connected correlation function, using cluster cumulant expansion +""" +function cc_correlation(bpc::BeliefPropagationCache, regions::Vector, counting_nums::Dict, obs) + logZs, cnums = cc_weights(bpc, regions, counting_nums; obs = obs) + op_strings, verts, _ = TN.collectobservable(obs) + if length(verts)==1 + return sum(logZs .* cnums).epsilon1 + else + return sum(logZs .* cnums).epsilon12 + end +end + +""" +onepoint or twopoint connected correlation function, using cluster expansion +""" +function cluster_correlation(bpc::BeliefPropagationCache, clusters::Vector, egs::Vector, interaction_graph, obs) + cluster_wts, logZs, ursells = cluster_weights(bpc, clusters, egs, interaction_graph; obs = obs) + op_strings, verts, _ = TN.collectobservable(obs) + cumul_dat = cumsum([sum([logZs[i][j] * ursells[i][j] for j=1:length(logZs[i])]) for i=1:length(logZs)]) + if length(verts)==1 + return cluster_wts, [d.epsilon1 for d=cumul_dat] + else + return cluster_wts, [d.epsilon12 for d=cumul_dat] + end +end \ No newline at end of file diff --git 
a/examples/clustercorrections.jl b/examples/clustercorrections.jl new file mode 100644 index 0000000..406356b --- /dev/null +++ b/examples/clustercorrections.jl @@ -0,0 +1,97 @@ +using TensorNetworkQuantumSimulator +const TN = TensorNetworkQuantumSimulator + +using ITensors + +using NamedGraphs +using Graphs +const NG = NamedGraphs +const G = Graphs +using NamedGraphs.NamedGraphGenerators: named_grid, named_hexagonal_lattice_graph + +using LinearAlgebra: norm + +using EinExprs: Greedy + +using Random +Random.seed!(1634) + +include("../cluster/expect-corrected.jl") + +function main(nx,ny) + χ = 3 + ITensors.disable_warn_order() + gs = [ + (named_grid((nx, 1)), "line", 0,-1), + (named_hexagonal_lattice_graph(nx, ny), "hexagonal", 6,11), + (named_grid((nx, ny)), "square", 4,11), + ] + + states = [] + for (g, g_str, smallest_loop_size, wmax) in gs + println("*****************************************") + println("Testing for $g_str lattice with $(NG.nv(g)) vertices") + wmax = min(wmax, NG.nv(g)) + ψ = TN.random_tensornetworkstate(ComplexF32, g, "S=1/2"; bond_dimension = χ) + + ψ = normalize(ψ; alg = "bp") + ψIψ = BeliefPropagationCache(ψ) + ψIψ = update(ψIψ) + + # BP expectation value + v = first(center(g)) + expect_bp = real(expect(ψIψ, ("Z", [v]))) + expect_exact_v = real(expect(ψ, ("Z", [v]); alg = "exact")) + clusters, egs, ig = TN.enumerate_clusters(g, wmax; must_contain=[v], min_deg = 1, min_v = smallest_loop_size) + cluster_wts, expects = cluster_correlation(ψIψ,clusters, egs, ig, ("Z", [v])) + + + regs = Dict() + cnums = Dict() + + cc_wts = [1; smallest_loop_size:wmax;] + for w=cc_wts + regs[w],_,cnums[w]=TN.build_region_family_correlation(g,v,v,w) + end + + expects_cc = Dict() + for w=cc_wts + expects_cc[w] = real(cc_correlation(ψIψ,regs[w], cnums[w], ("Z", [v]))) + end + + println("Bp expectation value for Z on site $(v) is $expect_bp") + println("Cluster expansion expectation values: $(cluster_wts), $(real.(expects))") + println("Cluster cumulant expansion: 
$(cc_wts), $([expects_cc[w] for w=cc_wts])") + println("Exact expectation value is $expect_exact_v") + + println("***********************************") + u = neighbors(g, v)[1] + obs = (["Z","Z"], [u,v]) + expect_exact_u = real(expect(ψ, ("Z", [u]); alg = "exact")) + expect_exact = real(expect(ψ,obs; alg = "exact")) - expect_exact_u * expect_exact_v + println("Calculating connected correlation function between $(v) and $(u)") + + clusters, egs, ig = TN.enumerate_clusters(g, max(1,min(wmax,2*smallest_loop_size)); must_contain=[u,v], min_deg = 1, min_v = 2) + + cluster_wts, expects = cluster_correlation(ψIψ,clusters, egs, ig, obs) + + regs = Dict() + cnums = Dict() + + cc_wts = [2;3:wmax;] + for w=cc_wts + regs[w],_,cnums[w]=TN.build_region_family_correlation(g,u,v,w) + end + + expects_cc = Dict() + for w=cc_wts + expects_cc[w] = real(cc_correlation(ψIψ,regs[w], cnums[w], obs)) + end + println("Cluster expansion expectation values: $(cluster_wts), $(real.(expects))") + println("Cluster cumulant expansion: $(cc_wts), $([expects_cc[w] for w=cc_wts])") + + println("Exact expectation value is $expect_exact") + push!(states, ψ) + end + return states +end \ No newline at end of file diff --git a/src/Apply/full_update.jl b/src/Apply/full_update.jl index 5ec9a67..174deb0 100644 --- a/src/Apply/full_update.jl +++ b/src/Apply/full_update.jl @@ -107,6 +107,7 @@ function optimise_p_q( nfullupdatesweeps = 10, print_fidelity_loss = false, envisposdef = true, + verbose = false, apply_kwargs..., ) p_cur, q_cur = factorize( @@ -138,16 +139,17 @@ function optimise_p_q( b_vec = b(p, q, o, envs, q_cur) M_p_partial = partial(M_p, envs, q_cur, qs_ind) - p_cur, info = linsolve( + p_cur, info1 = linsolve( M_p_partial, b_vec, p_cur; isposdef = envisposdef, ishermitian = false ) b_tilde_vec = b(p, q, o, envs, p_cur) M_p_tilde_partial = partial(M_p, envs, p_cur, ps_ind) - q_cur, info = linsolve( + q_cur, info2 = linsolve( M_p_tilde_partial, b_tilde_vec, q_cur; isposdef = envisposdef, 
ishermitian = false ) + verbose && println("Linsolve info, iteration $(i): $(info1), $(info2)") end fend = print_fidelity_loss ? fidelity(envs, p_cur, q_cur, p, q, o) : 0 diff --git a/src/MessagePassing/abstractbeliefpropagationcache.jl b/src/MessagePassing/abstractbeliefpropagationcache.jl index 9780c94..bd70e05 100644 --- a/src/MessagePassing/abstractbeliefpropagationcache.jl +++ b/src/MessagePassing/abstractbeliefpropagationcache.jl @@ -18,9 +18,9 @@ function rescale_vertices!( return not_implemented() end -function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex) +function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; op_strings::Function = v->"I", coeffs::Function = v->1, use_epsilon::Bool = false) incoming_ms = incoming_messages(bp_cache, vertex) - state = bp_factors(bp_cache, vertex) + state = bp_factors(bp_cache, vertex; op_strings = op_strings, coeffs = coeffs, use_epsilon = use_epsilon) contract_list = [state; incoming_ms] sequence = contraction_sequence(contract_list; alg = "optimal") return contract(contract_list; sequence)[] @@ -115,8 +115,8 @@ function edge_scalars( return map(e -> edge_scalar(bp_cache, e; kwargs...), edges) end -function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) - return vertex_scalars(bp_cache), edge_scalars(bp_cache) +function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache; kwargs...) + return vertex_scalars(bp_cache; kwargs...), edge_scalars(bp_cache) end function incoming_messages( @@ -194,16 +194,30 @@ function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) error("You need to specify a number of iterations for BP!") end bpc = copy(bpc) + diffs = zeros(alg.kwargs.maxiter) + tot_iter = alg.kwargs.maxiter + success = false for i in 1:alg.kwargs.maxiter diff = compute_error ? Ref(0.0) : nothing update_iteration!(alg, bpc, alg.kwargs.edge_sequence; (update_diff!) 
= diff) - if compute_error && (diff.x / length(alg.kwargs.edge_sequence)) <= alg.kwargs.tolerance - if alg.kwargs.verbose - println("BP converged to desired precision after $i iterations.") - end - break + if compute_error + diffs[i] = diff.x + if (diffs[i] / length(alg.kwargs.edge_sequence)) <= alg.kwargs.tolerance + if alg.kwargs.verbose + println("BP converged to desired precision after $i iterations.") + end + success = true + tot_iter = i + break + end end end + if compute_error && alg.kwargs.verbose + if !success + println("Did not converge.") + end + println("Diffs during message passing: $(diffs[1:tot_iter])") + end return bpc end @@ -239,21 +253,39 @@ function Adapt.adapt_structure(to, bpc::AbstractBeliefPropagationCache) return bpc end -function freenergy(bp_cache::AbstractBeliefPropagationCache) - numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) - if any(t -> real(t) < 0, numerator_terms) - numerator_terms = complex.(numerator_terms) +function adapt_complex(t) + if typeof(t)<:Hyper + return Hyper{ComplexF64}(t.value,t.epsilon1,t.epsilon2,t.epsilon12) + else + return complex(t) + end +end + +function get_real_part(t) + if typeof(t)<:Hyper + return real(t.value) + else + return real(t) end - if any(t -> real(t) < 0, denominator_terms) - denominator_terms = complex.(denominator_terms) +end + +function free_energy(bp_cache::AbstractBeliefPropagationCache; op_strings::Function = v->"I", coeffs::Function = v->1) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache; op_strings = op_strings, coeffs = coeffs, use_epsilon = true) + + # Skip this piece for now + if any(t -> get_real_part(t) < 0, numerator_terms) + numerator_terms = adapt_complex.(numerator_terms) + end + if any(t -> get_real_part(t) < 0, denominator_terms) + denominator_terms = adapt_complex.(denominator_terms) end any(iszero, denominator_terms) && return -Inf return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end -function 
partitionfunction(bp_cache::AbstractBeliefPropagationCache) - return exp(freenergy(bp_cache)) +function partitionfunction(bp_cache::AbstractBeliefPropagationCache; op_strings::Function = v->"I", coeffs::Function = v->1) + return exp(free_energy(bp_cache; op_strings = op_strings, coeffs = coeffs)) end function rescale_messages!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) @@ -278,4 +310,4 @@ function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) bpc = copy(bpc) rescale!(bpc, args...; kwargs...) return bpc -end +end \ No newline at end of file diff --git a/src/MessagePassing/beliefpropagationcache.jl b/src/MessagePassing/beliefpropagationcache.jl index 7da4c33..60a981e 100644 --- a/src/MessagePassing/beliefpropagationcache.jl +++ b/src/MessagePassing/beliefpropagationcache.jl @@ -13,10 +13,13 @@ struct BeliefPropagationCache{V, N <: AbstractTensorNetwork{V}, M <: Union{ITens end #TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages +# we shouldn't let this be negative, it doesn't make sense function message_diff(message_a::ITensor, message_b::ITensor) n_a, n_b = norm(message_a), norm(message_b) f = abs2(dot(message_a, message_b) / (n_a * n_b)) - return 1 - f + + # or do abs(1-f)? + return max(0,1 - f) end messages(bp_cache::BeliefPropagationCache) = bp_cache.messages @@ -86,13 +89,12 @@ end function rescale_vertices!( bpc::BeliefPropagationCache, - vertices::Vector - ) + vertices::Vector; kwargs...) tn = network(bpc) for v in vertices vn = vertex_scalar(bpc, v) - s = isreal(vn) ? sign(vn) : one(vn) + s = isreal(vn) ? 
sign(vn) : one(vn) if tn isa TensorNetworkState setindex_preserve!(tn, tn[v] * s * inv(sqrt(vn)), v) elseif tn isa TensorNetwork diff --git a/src/MessagePassing/clustercorrections.jl b/src/MessagePassing/clustercorrections.jl new file mode 100644 index 0000000..20d41ab --- /dev/null +++ b/src/MessagePassing/clustercorrections.jl @@ -0,0 +1,216 @@ +using NamedGraphs +using NamedGraphs: AbstractGraph,AbstractNamedGraph + +struct Cluster + loop_ids::Vector{Int} + multiplicities::Dict{Int, Int} + weight::Int + total_loops::Int +end + +struct Loop + vertices::Vector + edges::Vector + weight::Int +end + + +function canonical_cluster_signature(cluster::Cluster) + items = [(loop_id, cluster.multiplicities[loop_id]) for loop_id in sort(cluster.loop_ids)] + return (tuple(items...), cluster.weight) +end + +function build_interaction_graph(loops::Vector{Loop}) + """Build interaction graph with optimizations for speed.""" + interaction_graph = Dict{Int, Vector{Int}}() + n_loops = length(loops) + + println(" Building optimized interaction graph for $n_loops loops...") + flush(stdout) + + # Optimization 1: Pre-compute vertex sets once + vertex_sets = [Set(loop.vertices) for loop in loops] + + # Optimization 2: Build vertex-to-loops mapping for faster lookup + vertex_to_loops = Dict{Int, Vector{Int}}() + for (i, loop) in enumerate(loops) + for vertex in loop.vertices + if !haskey(vertex_to_loops, vertex) + vertex_to_loops[vertex] = Int[] + end + push!(vertex_to_loops[vertex], i) + end + end + + for i in 1:n_loops + interaction_graph[i] = unique(vcat([vertex_to_loops[v] for v=loops[i].vertices]...)) + + end + + return interaction_graph +end + +""" +Enumerate connected clusters using DFS starting from loops supported on target site. +Connectivity is guaranteed by growing through the interaction graph. 
+Courtesy of Frank Zhang and Siddhant Midha +""" +function dfs_enumerate_clusters_from_supported(all_loops::Vector{Loop}, supported_loop_ids::Vector{Int}, max_weight::Int, interaction_graph::Dict{Int, Vector{Int}}; verbose::Bool = false) + clusters = Cluster[] + seen_clusters = Set{Tuple}() + cluster_count = 0 + + verbose && println(" Starting DFS cluster enumeration...") + verbose && println(" Supported loops: $(length(supported_loop_ids)), Max weight: $max_weight") + + # DFS to grow clusters starting from each supported loop + function dfs_grow_cluster(current_cluster::Vector{Int}, current_weight::Int, + has_supported::Bool) + + # If we've found a valid cluster (has supported loop), record it + if has_supported && current_weight >= 1 + # Create cluster with multiplicities + multiplicities = Dict{Int, Int}() + for loop_id in current_cluster + multiplicities[loop_id] = get(multiplicities, loop_id, 0) + 1 + end + + cluster = Cluster( + collect(keys(multiplicities)), + multiplicities, + current_weight, + length(current_cluster) + ) + + # Avoid duplicates using canonical signature + signature = canonical_cluster_signature(cluster) + if !(signature in seen_clusters) + push!(seen_clusters, signature) + push!(clusters, cluster) + cluster_count += 1 + end + end + + # Stop if we've reached max weight + if current_weight >= max_weight + return + end + + # Find candidate loops to add (adjacent loops or multiplicities) + candidate_loops = Set{Int}() + + if isempty(current_cluster) + # Start with supported loops only + for loop_id in supported_loop_ids + if all_loops[loop_id].weight <= max_weight - current_weight + push!(candidate_loops, loop_id) + end + end + else + # Add loops connected to current cluster via interaction graph + for loop_id in current_cluster + # Add connected loops (touching loops) + for neighbor_id in get(interaction_graph, loop_id, Int[]) + if all_loops[neighbor_id].weight <= max_weight - current_weight + push!(candidate_loops, neighbor_id) + end + end + 
# Allow multiplicity increases (same loop added again) + if all_loops[loop_id].weight <= max_weight - current_weight + push!(candidate_loops, loop_id) + end + end + end + + # Try each candidate loop + for loop_id in candidate_loops + loop_weight = all_loops[loop_id].weight + new_weight = current_weight + loop_weight + + if new_weight <= max_weight + new_cluster = copy(current_cluster) + push!(new_cluster, loop_id) + new_has_supported = has_supported || (loop_id in supported_loop_ids) + + # Continue DFS (connectivity guaranteed by interaction graph) + dfs_grow_cluster(new_cluster, new_weight, new_has_supported) + end + end + end + + # Start DFS with empty cluster + dfs_grow_cluster(Int[], 0, false) + + verbose && println(" DFS enumeration completed: $cluster_count total clusters found") + return clusters +end + +""" +Build all clusters on named graph ng, up to a given weight. Optionally, must be supported on the vertices must_contain, in which case those vertices can be leaves + +This is overkill as it finds ALL subgraphs first, but my other implementation had bugs +""" +function enumerate_clusters(ng::NamedGraph, max_weight::Int; min_v::Int = 4, triangle_free::Bool = true, must_contain = [], min_deg::Int = 2, verbose::Bool = false) + g = ng.position_graph + ordered_indices = ng.vertices.ordered_indices + + verbose && println("Step 1: find embedded generalized loops") + subgraphs = generate_embedded_graphs(g, max_weight; min_v = min_v, triangle_free = triangle_free, min_deg = min_deg, leaf_vertices = [ng.vertices.index_positions[v] for v=must_contain]) + + # convert into form of LoopEnumeration.jl + loops = [Loop(sort(unique(vcat([[e[1],e[2]] for e=subg]...))), subg, length(subg)) for subg=subgraphs] + + verbose && println("Found $(length(loops)) loops") + + verbose && println("Step 2: Building interaction graph...") + interaction_graph = build_interaction_graph(loops) + + # DFS cluster enumeration + verbose && println("Step 3: DFS cluster enumeration...") + if 
isempty(must_contain) + supported_loops = [1:length(loops);] + else + supported_loops = findall(el->all(l->ng.vertices.index_positions[l] in el.vertices, must_contain), loops) + verbose && println("$(length(supported_loops)) supported...") + end + + all_clusters = dfs_enumerate_clusters_from_supported(loops, supported_loops, max_weight, interaction_graph, verbose = verbose) + verbose && println("Found $(length(all_clusters)) connected clusters") + + # converting loops into NamedGraphs, for use in tensor_weights + return all_clusters, [generalized_loop_named(l, ordered_indices) for l=loops], interaction_graph + +end + +""" +Convert from Loop into NamedGraph +""" +function generalized_loop_named(loop::Loop, ordered_indices) + g = NamedGraph(ordered_indices[loop.vertices]) + for e=loop.edges + add_edge!(g, ordered_indices[e[1]], ordered_indices[e[2]]) + end + g +end + +function ursell_function(cluster::Cluster, adj::Dict) + """ + Compute the Ursell function φ(W) for a connected cluster W. + """ + total_loops = cluster.total_loops + + if length(cluster.loop_ids) > 2 + for i=1:length(cluster.loop_ids) + for j=1:i-1 + if !(cluster.loop_ids[i] in adj[cluster.loop_ids[j]]) + error("Only implemented clusters corresponding to complete graphs for now, but got $(cluster)") + end + end + end + end + + no_vertices = sum(values(cluster.multiplicities)) + denominator = prod(factorial.(values(cluster.multiplicities))) + numerator = (-1)^(no_vertices - 1)* factorial(no_vertices - 1) + return numerator / denominator +end \ No newline at end of file diff --git a/src/MessagePassing/cumulant-clustercorrections.jl b/src/MessagePassing/cumulant-clustercorrections.jl new file mode 100644 index 0000000..5a4a55d --- /dev/null +++ b/src/MessagePassing/cumulant-clustercorrections.jl @@ -0,0 +1,375 @@ +using Graphs: nv, induced_subgraph, is_connected, connected_components +using NamedGraphs +using NamedGraphs: AbstractNamedGraph, position_graph +using Dictionaries + +const RegionKey = 
NTuple{N,Int} where {N} + +""" + to_key(vs::AbstractVector{Int})::RegionKey +Convert a collection of vertex IDs to a canonical, sorted tuple key. +""" +function to_key(vs::AbstractVector{Int})::RegionKey + return Tuple(sort!(collect(vs))) +end + +""" + key_intersection(a::RegionKey, b::RegionKey)::Vector{Int} +Return the vertex list of the intersection of two region keys. +""" +function key_intersection(a::RegionKey, b::RegionKey)::Vector{Int} + sa = Set(a); sb = Set(b) + return sort!(collect(intersect(sa, sb))) +end + +""" + is_loopful(g::SimpleGraph, key::RegionKey)::Bool +Check if the induced subgraph on `key` is connected and has at least one cycle. +For connected component `H`, loopfulness is `ne(H) - nv(H) + 1 > 0`. +""" +function is_loopful(g::SimpleGraph, key::RegionKey)::Bool + if length(key) < 3 + return false + end + vs = collect(key) + h, _ = induced_subgraph(g, vs) + # ensure connected (defensive; we try to keep regions connected elsewhere) + if nv(h) == 0 + return false + end + if !is_connected(h) + return false + end + return ne(h) - nv(h) + 1 > 0 +end + +""" + induced_components(g::SimpleGraph, vs::AbstractVector{Int})::Vector{RegionKey} +Return connected components (as RegionKey) of the induced subgraph on `vs`. +""" +function induced_components(g::SimpleGraph, vs::AbstractVector{Int})::Vector{RegionKey} + if isempty(vs) + return RegionKey[] + end + h, vmap = induced_subgraph(g, vs) # h has vertices 1..nv(h), vmap maps h->g + comps = connected_components(h) # comps as vectors of 1..nv(h) + # map back to original vertex IDs via `vs` + return [to_key(vs[c]) for c in comps] +end + +# --- Maximal regions under inclusion ------------------------------------------ + +""" + maximal_regions(regions::Set{RegionKey})::Set{RegionKey} +Select the inclusion-maximal regions from a set. 
+""" +function maximal_regions(regions::Set{RegionKey})::Set{RegionKey} + keys = collect(regions) + sort!(keys; by=length) # small to large + maximal = Set{RegionKey}() + for i in eachindex(keys) + a = keys[i] + is_sub = false + for j in (i+1):length(keys) + b = keys[j] + if length(Set(a)) < length(Set(b)) && issubset(Set(a), Set(b)) + is_sub = true + break + end + end + if !is_sub + push!(maximal, a) + end + end + return maximal +end + +# --- Close under intersections (with connected components) --------------------- +# Note that keeping the connected components of graphs with >1 component is unnecessary and is included here as legacy. +# use close_under_intersections_connected instead +""" + close_under_intersections(g::SimpleGraph, seed::Set{RegionKey}; loop_only::Bool=true)::Set{RegionKey} +Given `seed` regions, iteratively add intersections (split into connected components), +optionally keeping only loopful components; stop when no new regions appear. Optionally, only keep if component contains `must_contain` vertices. +""" +function close_under_intersections(g::SimpleGraph, seed::Set{RegionKey}; loop_only::Bool=true, must_contain = []) + R = Set(seed) + changed = true + while changed + changed = false + keys = collect(R) + for i in 1:length(keys)-1 + a = keys[i] + for j in (i+1):length(keys) + b = keys[j] + X = key_intersection(a, b) + comps = induced_components(g, X) + + for comp in comps + if loop_only && !is_loopful(g, comp) + continue + end + if !isempty(must_contain) && intersect(must_contain, comp) != must_contain + continue + end + if comp ∉ R + push!(R, comp) + changed = true + end + end + end + end + end + return R +end + +# --- Close under intersections (with connected components) --------------------- + +""" + close_under_intersections(g::SimpleGraph, seed::Set{RegionKey}; loop_only::Bool=true)::Set{RegionKey} +Given `seed` regions, iteratively add intersections. Only keep connected components. 
+optionally keeping only loopful components; stop when no new regions appear. Optionally, only keep if component contains `must_contain` vertices. +""" +function close_under_intersections_connected(g::SimpleGraph, seed::Set{RegionKey}; loop_only::Bool=true, must_contain = []) + R = Set(seed) + changed = true + while changed + changed = false + keys = collect(R) + for i in 1:length(keys)-1 + a = keys[i] + for j in (i+1):length(keys) + b = keys[j] + X = key_intersection(a, b) + + # must be connected and nonempty + if isempty(X) || !is_connected(induced_subgraph(g, X)[1]) + continue + end + + # must contain the required vertices + if !isempty(must_contain) && intersect(must_contain, X) != must_contain + continue + end + + comp = to_key(X) + # must be loopy, if loop_only + if loop_only && !is_loopful(g, comp) + continue + end + + if comp ∉ R + push!(R, comp) + changed = true + end + end + end + end + return R +end + +# --- Counting numbers (top-down Möbius) --------------------------------------- + +""" + counting_numbers(regions::Set{RegionKey})::Dict{RegionKey,Int} +Compute inclusion–exclusion counting numbers c(r): set c=1 for maximals, +then for other regions c(r) = 1 - sum_{a ⊃ r} c(a). 
+""" +function counting_numbers(regions::Set{RegionKey},maximals::Set{RegionKey})::Dict{RegionKey,Int} + R = collect(regions) + # sort supersets first (decreasing size) + sort!(R; by=r -> (-length(r), r)) + c = Dict{RegionKey,Int}() + + for r=maximals + c[r] = 1 + end + + # fill others (now supersets are guaranteed to have c set already) + for r in R + if haskey(c, r) + continue + end + s = 0 + for a in R + # proper subset + if length(Set(r)) < length(Set(a)) && issubset(Set(r), Set(a)) + s += get(c, a, 0) + end + end + c[r] = 1 - s + end + return c +end + + +""" +Find all subgraphs of g that contain both u and v, up to C vertices, +and have no other leaves +""" +function vertex_walks_up_to_C_regions(g::AbstractGraph, u::Integer, v::Integer, C::Int; buffer::Int = C) + walks = Set{RegionKey}() + + stack = [(u, Set(), -1,false,0)] # (current vertex, path, previous vertex, has_both,num_steps) + + while !isempty(stack) + node, path, prev_node, has_both,num_steps = pop!(stack) + if node==v + has_both = true + + end + + # contains both u and v, and completed a cycle, or a path from u to v + if has_both && (node==v || node in path) + push!(walks, to_key([vv for vv=union(Set([node]), path)])) + end + + push!(path, node) + @assert length(path) <= C + if num_steps==C + buffer + continue + end + for w in neighbors(g, node) + if w==prev_node + # don't backtrack + continue + end + + if length(path) < C || (w in path) + push!(stack, (w, copy(path), node, has_both, num_steps + 1)) + end + end + + end + return walks +end + +""" +Build clusters on graph g out of the regions regs +""" +function build_clusters(g::SimpleGraph, regs::Set; loop_only::Bool=true, must_contain = [], smart::Bool=true, verbose::Bool=false) + verbose && println("Finding maximal"); flush(stdout) + @time Rmax = maximal_regions(regs) + verbose && println("Finding intersections of $(length(Rmax)) regions"); flush(stdout) + if smart + R = close_under_intersections_connected(g, Rmax;loop_only=loop_only, 
must_contain = unique(must_contain)) + else + R = close_under_intersections(g, Rmax;loop_only=loop_only, must_contain = unique(must_contain)) + end + verbose && println("Finding counting numbers"); flush(stdout) + @time c = counting_numbers(R, Rmax) + return R, Rmax, c +end + +""" +Maps regions to NamedGraphs +""" +function map_regions_named(R::Set, Rmax::Set, c::Dict, vs_dict::Dictionary) + # map back to names + R = [map(v -> vs_dict[v], set) for set in R] + Rmax = [map(v -> vs_dict[v], set) for set in Rmax] + c = Dict(map(v -> vs_dict[v], key) => val for (key, val) in c) + return R, Rmax, c +end + +# prune branches except those ending at keep_vertices +function prune_cc(g::AbstractGraph, regions::Vector, counting_nums::Dict; keep_vertices = []) + counting_graphs = Dict() + for r=regions + if counting_nums[r] != 0 + eg = induced_subgraph(g, r)[1] + pb = Tuple(sort(prune_branches(eg, keep_vertices))) + # pg = induced_subgraph(eg, prune_branches(eg, keep_vertices))[1] + if haskey(counting_graphs, pb) + counting_graphs[pb] += counting_nums[r] + else + counting_graphs[pb] = counting_nums[r] + end + end + end + counting_graphs +end + +""" + build_region_family_correlation(g::SimpleGraph, u::Int, v::Int, C::Int) +Return (R, Rmax, c) where: + R :: Set{RegionKey} — full region family closed under intersections + Rmax :: Set{RegionKey} — maximal regions used as seeds + c :: Dict{RegionKey,Int} — counting numbers for all r ∈ R +""" +function build_region_family_correlation(g::SimpleGraph, u::Int, v::Int, C::Int; buffer::Int=C, smart::Bool=true, verbose::Bool=false) + verbose && println("Finding graphs"); flush(stdout) + @time regs = vertex_walks_up_to_C_regions(g,u,v,C; buffer = buffer) + R,Rmax,c = build_clusters(g, regs; loop_only = false, must_contain = [u,v], smart=smart, verbose = verbose) +end + +""" + Function to enumerate all generalized regions on a NamedGraph up to size Cluster_size, containing u and v. All other vertices have degree geq 2. 
+ Counting numbers are found via top-down Möbius inversion. + For one-point function, just set u=v. + Returns (R, Rmax, c) where: + R :: Vector{Vector{T}} — full region family closed under intersections + Rmax :: Vector{Vector{T}} — maximal regions (largest generalizd loops) used as seeds and not subsets of any other regions + c :: Dict{Vector{T},Int} — counting numbers for all r ∈ R +""" +function build_region_family_correlation(ng::NamedGraph, u, v, Cluster_size::Int; buffer::Int = Cluster_size, smart::Bool=true, prune::Bool=true,verbose::Bool=false) + g, vs_dict = position_graph(ng), Dictionary([i for i in 1:nv(ng)], collect(vertices(ng))) + mapped_u, mapped_v = ng.vertices.index_positions[u], ng.vertices.index_positions[v] + R, Rmax, c = build_region_family_correlation(g, mapped_u, mapped_v, Cluster_size; buffer = buffer, smart=smart,verbose=verbose) + R, Rmax, c = map_regions_named(R, Rmax, c, vs_dict) + if prune + c = prune_cc(ng,R,c; keep_vertices=[u,v]) + return collect(keys(c)), Rmax, c + else + return R, Rmax, c + end +end + +""" +Generate graphs up to isomorphism and then embed in the larger graph g. +max_v is max number of vertices +min_v is min number of vertices (e.g. 4 on square lattice) +""" +function generate_embedded_leafless_graphs(g::AbstractGraph, max_v::Int; min_v::Int=4, triangle_free::Bool = true, min_deg::Int = 2) + @assert min_deg >= 2 + k = maximum([degree(g,v) for v=vertices(g)]) + mygraphs = [generate_graphs(no_vertices, k*no_vertices; triangle_free = triangle_free, min_deg = min_deg, max_deg = k, connected = true) + for no_vertices=min_v:max_v] + + # now embed each one + embeddings = [[embed_graphs(g, subg) for subg = mygs] for mygs=mygraphs] + subgraphs = vcat(vcat(embeddings...)...) 
+ + # make into loopful induced subgraphs only + # Since they're connected and min_deg >= 2, will by definition be loopful + + regions = unique([to_key(unique(vcat([[e[1], e[2]] for e=subg]...))) for subg = subgraphs]) + return Set{RegionKey}(regions) +end + +""" + build_region_family(g::SimpleGraph, C::Int) +Return (R, Rmax, c) where: + R :: Set{RegionKey} — full region family closed under intersections + Rmax :: Set{RegionKey} — maximal regions used as seeds + c :: Dict{RegionKey,Int} — counting numbers for all r ∈ R +""" +function build_region_family(g::SimpleGraph, C::Int; min_deg::Int=2, min_v::Int=4,triangle_free::Bool=true, smart::Bool=true, verbose::Bool=false) + verbose && println("Finding graphs") + @time regs = generate_embedded_leafless_graphs(g, C; min_deg = min_deg, min_v=min_v,triangle_free=triangle_free) + build_clusters(g, regs; loop_only = true,smart=smart, verbose=verbose) +end + +""" + Function to enumerate all generalized regions on a NamedGraph up to size Cluster_size. + Counting numbers are found via top-down Möbius inversion. 
+ Returns (R, Rmax, c) where: + R :: Vector{Vector{T}} — full region family closed under intersections + Rmax :: Vector{Vector{T}} — maximal regions (largest generalized loops) used as seeds and not subsets of any other regions + c :: Dict{Vector{T},Int} — counting numbers for all r ∈ R +""" +function build_region_family(ng::NamedGraph, Cluster_size::Int; min_deg::Int=2, min_v::Int=4,triangle_free::Bool=true, smart::Bool=true, verbose::Bool=false) + g, vs_dict = position_graph(ng), Dictionary([i for i in 1:nv(ng)], collect(vertices(ng))) + R, Rmax, c = build_region_family(g, Cluster_size; min_deg = min_deg, min_v=min_v,triangle_free=triangle_free, smart=smart, verbose=verbose) + map_regions_named(R, Rmax, c, vs_dict) +end \ No newline at end of file diff --git a/src/MessagePassing/loopcorrection.jl b/src/MessagePassing/loopcorrection.jl index bbbe7a4..55cc835 100644 --- a/src/MessagePassing/loopcorrection.jl +++ b/src/MessagePassing/loopcorrection.jl @@ -1,4 +1,5 @@ using NamedGraphs.GraphsExtensions: boundary_edges +using HyperDualNumbers function loopcorrected_partitionfunction( bp_cache::BeliefPropagationCache, @@ -65,32 +66,47 @@ end #Get the all edges incident to the region specified by the vector of edges passed function NamedGraphs.GraphsExtensions.boundary_edges( bpc::BeliefPropagationCache, - es::Vector{<:NamedEdge}, - ) + es::Vector{<:NamedEdge}) + vs = unique(vcat(src.(es), dst.(es))) + bpes = NamedEdge[] for v in vs - incoming_es = NamedGraphs.GraphsExtensions.boundary_edges(bpc, [v]; dir = :in) + incoming_es = boundary_edges(bpc, [v]; dir = :in) incoming_es = filter(e -> e ∉ es && reverse(e) ∉ es, incoming_es) append!(bpes, incoming_es) end return bpes end -#Compute the contraction of the bp configuration specified by the edge induced subgraph eg -function weight(bpc::BeliefPropagationCache, eg) +#Compute the contraction of the bp configuration specified by the edge induced subgraph eg. Insert I + epsilon O on up to two sites.
+function weight(bpc::BeliefPropagationCache, eg; project_out::Bool = true, op_strings::Function = v->"I", coeffs::Function = v->1, rescales = Dictionary(1 for v=vertices(eg))) vs = collect(vertices(eg)) es = collect(edges(eg)) - bpc, antiprojectors = sim_edgeinduced_subgraph(bpc, eg) - incoming_ms = - ITensor[message(bpc, e) for e in boundary_edges(bpc, es)] - local_tensors = reduce(vcat, [bp_factors(bpc, v) for v in vs]) - ts = [incoming_ms; local_tensors; antiprojectors] + + if project_out + bpc, antiprojectors = sim_edgeinduced_subgraph(bpc, eg) + end + if isempty(es) + incoming_ms = ITensor[message(bpc, e) for e in boundary_edges(bpc, vs; dir=:in)] + else + incoming_ms = ITensor[message(bpc, e) for e in boundary_edges(bpc, es)] + + end + local_tensors = reduce(vcat, bp_factors(bpc, vs; op_strings = op_strings, coeffs = coeffs, use_epsilon = true)) + + if project_out + ts = [incoming_ms; local_tensors; antiprojectors] + else + ts = [incoming_ms; local_tensors] + end + seq = any(hasqns.(ts)) ? contraction_sequence(ts; alg = "optimal") : contraction_sequence(ts; alg = "einexpr", optimizer = Greedy()) - return contract(ts; sequence = seq)[] + output = contract(ts; sequence = seq)[] + return output / prod([rescales[v] for v=vs]) end #Vectorized version of weight -function weights(bpc::BeliefPropagationCache, egs) - return [weight(bpc, eg) for eg in egs] +function weights(bpc::BeliefPropagationCache, egs; kwargs...) + return [weight(bpc, eg; kwargs...) 
for eg in egs] end diff --git a/src/TensorNetworkQuantumSimulator.jl b/src/TensorNetworkQuantumSimulator.jl index eb59923..2f69ebb 100644 --- a/src/TensorNetworkQuantumSimulator.jl +++ b/src/TensorNetworkQuantumSimulator.jl @@ -9,12 +9,15 @@ include("TensorNetworks/tensornetwork.jl") include("TensorNetworks/tensornetworkstate.jl") include("TensorNetworks/tensornetworkstate_constructors.jl") include("contraction_sequences.jl") +include("graph_enumeration.jl") include("Forms/bilinearform.jl") include("Forms/quadraticform.jl") include("MessagePassing/abstractbeliefpropagationcache.jl") include("MessagePassing/beliefpropagationcache.jl") include("MessagePassing/boundarympscache.jl") include("MessagePassing/loopcorrection.jl") +include("MessagePassing/clustercorrections.jl") +include("MessagePassing/cumulant-clustercorrections.jl") include("graph_ops.jl") include("utils.jl") diff --git a/src/TensorNetworks/tensornetworkstate.jl b/src/TensorNetworks/tensornetworkstate.jl index 410278d..0483faa 100644 --- a/src/TensorNetworks/tensornetworkstate.jl +++ b/src/TensorNetworks/tensornetworkstate.jl @@ -1,4 +1,5 @@ using ITensors: random_itensor +using HyperDualNumbers #TODO: Make this show() nicely. 
struct TensorNetworkState{V} <: AbstractTensorNetwork{V} @@ -39,7 +40,7 @@ function Base.setindex!(tns::TensorNetworkState, value::ITensor, v) return tns end -function norm_factors(tns::TensorNetworkState, verts::Vector; op_strings::Function = v -> "I") +function norm_factors(tns::TensorNetworkState, verts::Vector; op_strings::Function = v -> "I", coeffs::Function = v->1, use_epsilon::Bool = false) factors = ITensor[] for v in verts sinds = siteinds(tns, v) @@ -47,17 +48,18 @@ function norm_factors(tns::TensorNetworkState, verts::Vector; op_strings::Functi tnv_dag = dag(prime(tnv)) if op_strings(v) == "I" tnv_dag = replaceinds(tnv_dag, prime.(sinds), sinds) - append!(factors, ITensor[tnv, tnv_dag]) + append!(factors, ITensor[coeffs(v) * tnv, tnv_dag]) else - op = adapt(datatype(tnv))(ITensors.op(op_strings(v), only(sinds))) + op = use_epsilon ? Hyper(1,0,0,0) * ITensors.op("I", only(sinds)) + coeffs(v) * ITensors.op(op_strings(v), only(sinds)) : coeffs(v) * ITensors.op(op_strings(v), only(sinds)) append!(factors, ITensor[tnv, tnv_dag, op]) end end return factors end -norm_factors(tns::TensorNetworkState, v) = norm_factors(tns, [v]) -bp_factors(tns::TensorNetworkState, v) = norm_factors(tns, v) +norm_factors(tns::TensorNetworkState, v; kwargs...) = norm_factors(tns, [v]; kwargs...) +bp_factors(tns::TensorNetworkState, v; kwargs...) = norm_factors(tns, v; kwargs...) +bp_factors(tns::TensorNetworkState, verts::Vector; kwargs...) = norm_factors(tns, verts; kwargs...) 
function default_message(tns::TensorNetworkState, edge::AbstractEdge) linds = virtualinds(tns, edge) diff --git a/src/graph_enumeration.jl b/src/graph_enumeration.jl new file mode 100644 index 0000000..3c9d915 --- /dev/null +++ b/src/graph_enumeration.jl @@ -0,0 +1,107 @@ +using Graphs: loadgraphs +using GraphIO +using Graphs.Experimental +using Graphs.SimpleGraphs +using NamedGraphs # from TensorNetworkQuantumSimulator +using StatsBase + +# Generate non-isomorphic graphs with geng +# on square lattice, choose triangle_free = true to speed things up +function generate_graphs(n::Int, max_e::Int; min_e::Int=0, triangle_free::Bool = true, min_deg::Int=1, max_deg::Int=4, connected::Bool = true) + # Build geng command + if connected + if triangle_free + cmd = `geng -c -t -d$(min_deg) -D$(max_deg) $n $(min_e):$(max_e)` + else + cmd = `geng -c -d$(min_deg) -D$(max_deg) $n $(min_e):$(max_e)` + end + elseif triangle_free + cmd = `geng -t -d$(min_deg) -D$(max_deg) $n $(min_e):$(max_e)` + else + cmd = `geng -d$(min_deg) -D$(max_deg) $n $(min_e):$(max_e)` + end + + graphs = SimpleGraph[] + try open(cmd, "r") do io + for g in loadgraphs(io, GraphIO.Graph6.Graph6Format()) + push!(graphs, g[2]) + end + end + catch e + println("Couldn't find graphs with these parameters: $(cmd)") + end + + # also + return graphs +end + +function map_edges(subg::SimpleGraph, iso_map::Vector) + iso_map_dict = Dict(v[2]=>v[1] for v=iso_map) + [(min(iso_map_dict[src(e)],iso_map_dict[dst(e)]),max(iso_map_dict[src(e)],iso_map_dict[dst(e)])) for e=edges(subg)] +end + +function embed_graphs(g::AbstractGraph, subg::SimpleGraph) + all_embeddings = unique(sort.([map_edges(subg, iso_map) for iso_map=all_subgraphisomorph(g, subg)])) +end + +function is_valid_graph(g; leaf_vertices = []) + vertex_counts = countmap(vcat([[e[1],e[2]] for e=g]...)) + for (v,k)=vertex_counts + if !(v in leaf_vertices) && k < 2 + return false + end + end + return true +end + +""" +Generate graphs up to isomorphism and then embed in 
the larger graph g. +max_weight is max number of edges +min_v is min number of vertices (e.g. 4 on square lattice) +""" +function generate_embedded_graphs(g::AbstractGraph, max_weight::Int; min_v::Int=4, triangle_free::Bool = true, min_deg::Int = 2, leaf_vertices = []) + + k = maximum([degree(g,v) for v=vertices(g)]) + # max edge weight in max_weight, so the max number of vertices is max_weight-1 + mygraphs = [generate_graphs(no_vertices, max_weight; min_e=no_vertices-1, triangle_free = triangle_free, min_deg = min_deg, max_deg = k, connected = true) for no_vertices=min_v:max_weight+1] + + # only try to embed graphs with at most length(leaf_vertices) leaves + mygraphs = [filter(subg->count(isone, [degree(subg,v) for v=vertices(subg)])<=length(leaf_vertices), mygs) for mygs=mygraphs] + + # now embed each one + embeddings = [[embed_graphs(g, subg) for subg = mygs] for mygs=mygraphs] + subgraphs = filter(g->is_valid_graph(g; leaf_vertices = leaf_vertices), vcat(vcat(embeddings...)...)) +end + +""" + prune_branches(g::AbstractGraph, keep_vertices) + +Return a new graph obtained from `g` by pruning away leaf branches +(iteratively) except those that terminate at any vertex in `keep_vertices`. + +Returns the vertices in the pruned graph +""" +function prune_branches(g::AbstractGraph, keep_vertices) + keep_set = Set(keep_vertices) + alive = Dict(v=>true for v=vertices(g)) + + changed = true + while changed + changed = false + # compute degree within the induced alive-subgraph (count alive neighbors) + for v=vertices(g) + if alive[v] + cnt = sum([alive[u] for u=neighbors(g,v)]) + + # remove if it's a leaf (degree 1) or isolated (degree 0), + # and not in keep_set. + if cnt <= 1 && !(v in keep_set) + alive[v] = false + changed = true + end + end + end + end + + return [v for v=vertices(g) if alive[v]] +end