From ddd1900da63ccc050b692534d4fd83a3440f3ab8 Mon Sep 17 00:00:00 2001 From: Michael Dales Date: Wed, 22 Nov 2023 11:20:53 +0000 Subject: [PATCH 01/12] Test using full K for generating S --- methods/matching/find_pairs.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 7f17782..2f24524 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -61,11 +61,11 @@ def find_match_iteration( logging.info("Preparing s_set...") m_dist_thresholded_df = m_set[DISTANCE_COLUMNS] / thresholds_for_columns - k_subset_dist_thresholded_df = k_subset[DISTANCE_COLUMNS] / thresholds_for_columns + k_set_dist_thresholded_df = k_set[DISTANCE_COLUMNS] / thresholds_for_columns # convert to float32 numpy arrays and make them contiguous for numba to vectorise m_dist_thresholded = np.ascontiguousarray(m_dist_thresholded_df, dtype=np.float32) - k_subset_dist_thresholded = np.ascontiguousarray(k_subset_dist_thresholded_df, dtype=np.float32) + k_set_dist_thresholded = np.ascontiguousarray(k_set_dist_thresholded_df, dtype=np.float32) # LUC columns are all named with the year in, so calculate the column names # for the years we are intested in @@ -78,19 +78,18 @@ def find_match_iteration( # similar to the above, make the hard match columns contiguous float32 numpy arrays m_dist_hard = np.ascontiguousarray(m_set[hard_match_columns].to_numpy()).astype(np.int32) - k_subset_dist_hard = np.ascontiguousarray(k_subset[hard_match_columns].to_numpy()).astype(np.int32) + k_set_dist_hard = np.ascontiguousarray(k_set[hard_match_columns].to_numpy()).astype(np.int32) - # Methodology 6.5.5: S should be 10 times the size of K, in order to achieve this for every - # pixel in the subsample (which is 10% the size of K) we select 100 pixels. - required = 100 + # Methodology 6.5.5: S should be 10 times the size of K + required = 10 logging.info("Running make_s_set_mask... 
required: %d", required) - starting_positions = rng.integers(0, int(m_dist_thresholded.shape[0]), int(k_subset_dist_thresholded.shape[0])) + starting_positions = rng.integers(0, int(m_dist_thresholded.shape[0]), int(k_set_dist_thresholded.shape[0])) s_set_mask_true, no_potentials = make_s_set_mask( m_dist_thresholded, - k_subset_dist_thresholded, + k_set_dist_thresholded, m_dist_hard, - k_subset_dist_hard, + k_set_dist_hard, starting_positions, required ) @@ -176,22 +175,22 @@ def find_match_iteration( @jit(nopython=True, fastmath=True, error_model="numpy") def make_s_set_mask( m_dist_thresholded: np.ndarray, - k_subset_dist_thresholded: np.ndarray, + k_set_dist_thresholded: np.ndarray, m_dist_hard: np.ndarray, - k_subset_dist_hard: np.ndarray, + k_set_dist_hard: np.ndarray, starting_positions: np.ndarray, required: int ): + k_size = k_set_dist_thresholded.shape[0] m_size = m_dist_thresholded.shape[0] - k_size = k_subset_dist_thresholded.shape[0] s_include = np.zeros(m_size, dtype=np.bool_) k_miss = np.zeros(k_size, dtype=np.bool_) for k in range(k_size): matches = 0 - k_row = k_subset_dist_thresholded[k, :] - k_hard = k_subset_dist_hard[k] + k_row = k_set_dist_thresholded[k, :] + k_hard = k_set_dist_hard[k] for index in range(m_size): m_index = (index + starting_positions[k]) % m_size From 4dd4f92210bb2bc15c663e28a0825659cd379ff1 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Fri, 1 Dec 2023 11:40:34 +0000 Subject: [PATCH 02/12] WIP: try binning K and M by hard_match_columns; doesn't help as most of K and M are in same bin --- methods/matching/find_pairs.py | 65 ++++++++++++++++++---- methods/matching/find_potential_matches.py | 9 ++- 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 2f24524..385b6a0 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -1,4 +1,5 @@ import argparse +from collections import defaultdict import os import logging from functools import partial @@ -8,8 +9,9 @@ import pandas as pd from methods.common.luc import luc_matching_columns +from methods.matching.find_potential_matches import key_builder -REPEAT_MATCH_FINDING = 100 +REPEAT_MATCH_FINDING = 1 DEFAULT_DISTANCE = 10000000.0 DEBUG = False @@ -62,10 +64,12 @@ def find_match_iteration( m_dist_thresholded_df = m_set[DISTANCE_COLUMNS] / thresholds_for_columns k_set_dist_thresholded_df = k_set[DISTANCE_COLUMNS] / thresholds_for_columns + + # TODO: Split these into bins which require only looking in a couple to find a match # convert to float32 numpy arrays and make them contiguous for numba to vectorise - m_dist_thresholded = np.ascontiguousarray(m_dist_thresholded_df, dtype=np.float32) - k_set_dist_thresholded = np.ascontiguousarray(k_set_dist_thresholded_df, dtype=np.float32) + #m_dist_thresholded = np.ascontiguousarray(m_dist_thresholded_df, dtype=np.float32) + #k_set_dist_thresholded = np.ascontiguousarray(k_set_dist_thresholded_df, dtype=np.float32) # LUC columns are all named with the year in, so calculate the column names # for the years we are intested in @@ -73,23 +77,62 @@ def find_match_iteration( # As well as all the LUC columns for later use luc_columns = [x for x in m_set.columns if x.startswith('luc')] - hard_match_columns = ['country', 'ecoregion', luc10, luc5, luc0] + hard_match_columns = ['ecoregion', 'country', luc0, luc5, luc10] # This must match the order given in key_builder assert len(hard_match_columns) == HARD_COLUMN_COUNT + build_key = key_builder(start_year) # similar to 
the above, make the hard match columns contiguous float32 numpy arrays - m_dist_hard = np.ascontiguousarray(m_set[hard_match_columns].to_numpy()).astype(np.int32) - k_set_dist_hard = np.ascontiguousarray(k_set[hard_match_columns].to_numpy()).astype(np.int32) + #m_dist_hard = np.ascontiguousarray(m_set[hard_match_columns].to_numpy()).astype(np.int32) + #k_set_dist_hard = np.ascontiguousarray(k_set[hard_match_columns].to_numpy()).astype(np.int32) + + def make_bins(rows, normalised_rows): + bins = defaultdict(lambda : []) + for i, row in rows.iterrows(): + bins[build_key(row)].append(normalised_rows.iloc[i]) + return {k: np.array(bin, dtype=np.float32) for k, bin in bins.items()} + + logging.info("|K| total %d |M| total %d", len(k_set), len(m_set)) + k_normalised_hard_bins = make_bins(k_set, k_set_dist_thresholded_df) + + def select_bins(rows, normalised_rows, k_bins): + m_bins = {} + for k in k_bins.keys(): + column_values = build_key.lookup(k) + print(column_values) + print(rows.iloc[0:5]) + print(rows.iloc[0:5][hard_match_columns]) + matches = np.all(rows[hard_match_columns] == column_values, axis=1) + print(matches) + print(np.sum(matches)) + m_bins[k] = normalised_rows[matches] + logging.info("Bin %a |K|: %d |M|: %d", column_values, len(k_bins[k]), len(m_bins[k])) + + m_normalised_hard_bins = select_bins(m_set, m_dist_thresholded_df, k_normalised_hard_bins) + + exit() + + for k, values in k_normalised_hard_bins: + if k not in m_normalised_hard_bins: + m_normalised_hard_bins[k] = np.empty((0, len(DISTANCE_COLUMNS))) + logging.info("No matches for bin of size %d with params %a", len(values), build_key.lookup(k)) + else: + logging.info("|K| %d |S| %d for bin with params %a", len(values), len(m_normalised_hard_bins[k]), build_key.lookup(k)) + + for k, values in m_normalised_hard_bins: + if k not in k_normalised_hard_bins: + logging.info("Eliminated M bin of size %d with params %a", len(values), build_key.lookup(k)) + exit() # Methodology 6.5.5: S should be 10 times the size of K required = 10 + # TODO: From here, most of this code can run per-bin in k + logging.info("Running make_s_set_mask... 
required: %d", required) - starting_positions = rng.integers(0, int(m_dist_thresholded.shape[0]), int(k_set_dist_thresholded.shape[0])) + starting_positions = {(k, rng.integers(0, int(m_normalised_hard_bins[k].shape[0], k_normalised_hard_bins[k].shape[0]))) for k, _ in k_normalised_hard_bins} s_set_mask_true, no_potentials = make_s_set_mask( - m_dist_thresholded, - k_set_dist_thresholded, - m_dist_hard, - k_set_dist_hard, + m_normalised_hard_bins, + k_normalised_hard_bins, starting_positions, required ) diff --git a/methods/matching/find_potential_matches.py b/methods/matching/find_potential_matches.py index f69865a..556c36f 100644 --- a/methods/matching/find_potential_matches.py +++ b/methods/matching/find_potential_matches.py @@ -27,6 +27,7 @@ def build_key(ecoregion, country, luc0, luc5, luc10): """Create a 64-bit key for fields that must match exactly""" + return (int(ecoregion) << 32) | (int(country) << 16) | (int(luc0) << 10) | (int(luc5) << 5) | (int(luc10)) if ecoregion < 0 or ecoregion > 0x7fffffff: raise ValueError("Ecoregion doesn't fit in 31 bits") if country < 0 or country > 0xffff: @@ -37,12 +38,16 @@ def build_key(ecoregion, country, luc0, luc5, luc10): raise ValueError("luc5 doesn't fit in 5 bits") if luc10 < 0 or luc10 > 0x1f: raise ValueError("luc10 doesn't fit in 5 bits") - return (int(ecoregion) << 32) | (int(country) << 16) | (int(luc0) << 10) | (int(luc5) << 5) | (int(luc10)) def key_builder(start_year: int): luc0, luc5, luc10 = luc_matching_columns(start_year) + lookup = {} def _build_key(row): - return build_key(row.ecoregion, row.country, row[luc0], row[luc5], row[luc10]) + value = build_key(row.ecoregion, row.country, row[luc0], row[luc5], row[luc10]) + if value not in lookup: + lookup[value] = (row.ecoregion, row.country, row[luc0], row[luc5], row[luc10]) + return value + _build_key.lookup = lambda key: lookup[key] return _build_key def load_k( From c3b86df44846a185c8601d2da9ed9a9b78007335 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Fri, 1 Dec 2023 11:41:08 +0000 Subject: [PATCH 03/12] Revert "WIP: try binning K and M by hard_match_columns; doesn't help as most of K and M are in same bin" This reverts commit 4dd4f92210bb2bc15c663e28a0825659cd379ff1. 
--- methods/matching/find_pairs.py | 65 ++++------------------ methods/matching/find_potential_matches.py | 9 +-- 2 files changed, 13 insertions(+), 61 deletions(-) diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 385b6a0..2f24524 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -1,5 +1,4 @@ import argparse -from collections import defaultdict import os import logging from functools import partial @@ -9,9 +8,8 @@ import pandas as pd from methods.common.luc import luc_matching_columns -from methods.matching.find_potential_matches import key_builder -REPEAT_MATCH_FINDING = 1 +REPEAT_MATCH_FINDING = 100 DEFAULT_DISTANCE = 10000000.0 DEBUG = False @@ -64,12 +62,10 @@ def find_match_iteration( m_dist_thresholded_df = m_set[DISTANCE_COLUMNS] / thresholds_for_columns k_set_dist_thresholded_df = k_set[DISTANCE_COLUMNS] / thresholds_for_columns - - # TODO: Split these into bins which require only looking in a couple to find a match # convert to float32 numpy arrays and make them contiguous for numba to vectorise - #m_dist_thresholded = np.ascontiguousarray(m_dist_thresholded_df, dtype=np.float32) - #k_set_dist_thresholded = np.ascontiguousarray(k_set_dist_thresholded_df, dtype=np.float32) + m_dist_thresholded = np.ascontiguousarray(m_dist_thresholded_df, dtype=np.float32) + k_set_dist_thresholded = np.ascontiguousarray(k_set_dist_thresholded_df, dtype=np.float32) # LUC columns are all named with the year in, so calculate the column names # for the years we are intested in @@ -77,62 +73,23 @@ def find_match_iteration( # As well as all the LUC columns for later use luc_columns = [x for x in m_set.columns if x.startswith('luc')] - hard_match_columns = ['ecoregion', 'country', luc0, luc5, luc10] # This must match the order given in key_builder + hard_match_columns = ['country', 'ecoregion', luc10, luc5, luc0] assert len(hard_match_columns) == HARD_COLUMN_COUNT - build_key = key_builder(start_year) # similar to the above, make the hard match columns contiguous float32 numpy arrays - #m_dist_hard = np.ascontiguousarray(m_set[hard_match_columns].to_numpy()).astype(np.int32) - #k_set_dist_hard = np.ascontiguousarray(k_set[hard_match_columns].to_numpy()).astype(np.int32) - - def make_bins(rows, normalised_rows): - bins = defaultdict(lambda : []) - for i, row in rows.iterrows(): - bins[build_key(row)].append(normalised_rows.iloc[i]) - return {k: np.array(bin, dtype=np.float32) for k, bin in bins.items()} - - logging.info("|K| total %d |M| total %d", len(k_set), len(m_set)) - k_normalised_hard_bins = make_bins(k_set, k_set_dist_thresholded_df) - - def select_bins(rows, normalised_rows, k_bins): - m_bins = {} - for k in k_bins.keys(): - column_values = build_key.lookup(k) - print(column_values) - print(rows.iloc[0:5]) - print(rows.iloc[0:5][hard_match_columns]) - matches = np.all(rows[hard_match_columns] == column_values, axis=1) - print(matches) - print(np.sum(matches)) - m_bins[k] = normalised_rows[matches] - logging.info("Bin %a |K|: %d |M|: %d", column_values, len(k_bins[k]), len(m_bins[k])) - - m_normalised_hard_bins = select_bins(m_set, m_dist_thresholded_df, k_normalised_hard_bins) - - exit() - - for k, values in k_normalised_hard_bins: - if k not in m_normalised_hard_bins: - m_normalised_hard_bins[k] = np.empty((0, len(DISTANCE_COLUMNS))) - logging.info("No matches for bin of size %d with params %a", len(values), build_key.lookup(k)) - else: - logging.info("|K| %d |S| %d for bin with params %a", len(values), 
len(m_normalised_hard_bins[k]), build_key.lookup(k)) - - for k, values in m_normalised_hard_bins: - if k not in k_normalised_hard_bins: - logging.info("Eliminated M bin of size %d with params %a", len(values), build_key.lookup(k)) + m_dist_hard = np.ascontiguousarray(m_set[hard_match_columns].to_numpy()).astype(np.int32) + k_set_dist_hard = np.ascontiguousarray(k_set[hard_match_columns].to_numpy()).astype(np.int32) - exit() # Methodology 6.5.5: S should be 10 times the size of K required = 10 - # TODO: From here, most of this code can run per-bin in k - logging.info("Running make_s_set_mask... required: %d", required) - starting_positions = {(k, rng.integers(0, int(m_normalised_hard_bins[k].shape[0], k_normalised_hard_bins[k].shape[0]))) for k, _ in k_normalised_hard_bins} + starting_positions = rng.integers(0, int(m_dist_thresholded.shape[0]), int(k_set_dist_thresholded.shape[0])) s_set_mask_true, no_potentials = make_s_set_mask( - m_normalised_hard_bins, - k_normalised_hard_bins, + m_dist_thresholded, + k_set_dist_thresholded, + m_dist_hard, + k_set_dist_hard, starting_positions, required ) diff --git a/methods/matching/find_potential_matches.py b/methods/matching/find_potential_matches.py index 556c36f..f69865a 100644 --- a/methods/matching/find_potential_matches.py +++ b/methods/matching/find_potential_matches.py @@ -27,7 +27,6 @@ def build_key(ecoregion, country, luc0, luc5, luc10): """Create a 64-bit key for fields that must match exactly""" - return (int(ecoregion) << 32) | (int(country) << 16) | (int(luc0) << 10) | (int(luc5) << 5) | (int(luc10)) if ecoregion < 0 or ecoregion > 0x7fffffff: raise ValueError("Ecoregion doesn't fit in 31 bits") if country < 0 or country > 0xffff: @@ -38,16 +37,12 @@ def build_key(ecoregion, country, luc0, luc5, luc10): raise ValueError("luc5 doesn't fit in 5 bits") if luc10 < 0 or luc10 > 0x1f: raise ValueError("luc10 doesn't fit in 5 bits") + return (int(ecoregion) << 32) | (int(country) << 16) | (int(luc0) << 10) | (int(luc5) << 5) | (int(luc10)) def key_builder(start_year: int): luc0, luc5, luc10 = luc_matching_columns(start_year) - lookup = {} def _build_key(row): - value = build_key(row.ecoregion, row.country, row[luc0], row[luc5], row[luc10]) - if value not in lookup: - lookup[value] = (row.ecoregion, row.country, row[luc0], row[luc5], row[luc10]) - return value - _build_key.lookup = lambda key: lookup[key] + return build_key(row.ecoregion, row.country, row[luc0], row[luc5], row[luc10]) return _build_key def load_k( From 4589c10ae2e9a6e85f290ed92560873727b3ba27 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Fri, 1 Dec 2023 17:27:47 +0000 Subject: [PATCH 04/12] WIP: using RTree to speed up search for possible random M candidates for S --- methods/matching/find_pairs.py | 260 ++++++++++++++++++++++++++++++--- 1 file changed, 237 insertions(+), 23 deletions(-) diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 2f24524..91719db 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -3,13 +3,15 @@ import logging from functools import partial from multiprocessing import Pool, cpu_count, set_start_method -from numba import jit # type: ignore +from numba import jit, float32, int64, deferred_type # type: ignore +from numba.experimental import jitclass + import numpy as np import pandas as pd from methods.common.luc import luc_matching_columns -REPEAT_MATCH_FINDING = 100 +REPEAT_MATCH_FINDING = 1 DEFAULT_DISTANCE = 10000000.0 DEBUG = False @@ -35,8 +37,8 @@ def find_match_iteration( 
logging.info("Loading K from %s", k_parquet_filename) - # Methodology 6.5.7: For a 10% sample of K k_set = pd.read_parquet(k_parquet_filename) + # Methodology 6.5.7: For a 10% sample of K k_subset = k_set.sample( frac=0.1, random_state=rng @@ -62,6 +64,17 @@ def find_match_iteration( m_dist_thresholded_df = m_set[DISTANCE_COLUMNS] / thresholds_for_columns k_set_dist_thresholded_df = k_set[DISTANCE_COLUMNS] / thresholds_for_columns + # IDEA: Maybe we can bin these somehow? + + # Rearrange columns by variance so we throw out the least likely to match first + # except the bottom three which are deforestation CPCs and have more cross-variance between K and M + variances = np.std(m_dist_thresholded_df, axis=0) + cols = DISTANCE_COLUMNS + order = np.argsort(-variances.to_numpy()) + order = np.roll(order, 3) + new_cols = [cols[o] for o in order] + m_dist_thresholded_df = m_dist_thresholded_df[new_cols] + k_set_dist_thresholded_df = k_set_dist_thresholded_df[new_cols] # convert to float32 numpy arrays and make them contiguous for numba to vectorise m_dist_thresholded = np.ascontiguousarray(m_dist_thresholded_df, dtype=np.float32) @@ -91,7 +104,8 @@ def find_match_iteration( m_dist_hard, k_set_dist_hard, starting_positions, - required + required, + rng ) logging.info("Done make_s_set_mask. s_set_mask.shape: %a", {s_set_mask_true.shape}) @@ -99,6 +113,7 @@ def find_match_iteration( s_set = m_set[s_set_mask_true] potentials = np.invert(no_potentials) + # FIXME: Not sure this line is meaningful any more if potentials drawn from K? k_subset = k_subset[potentials] logging.info("Finished preparing s_set. shape: %a", {s_set.shape}) @@ -172,8 +187,208 @@ def find_match_iteration( logging.info("Finished find match iteration") -@jit(nopython=True, fastmath=True, error_model="numpy") +@jitclass +class RTree: + def __init__(): + pass + + def contains(self, range) -> bool: + raise NotImplemented() + + def depth(self) -> int: + return 1 + + def size(self) -> int: + return 1 + + def members(self, range) -> [np.ndarray]: + raise NotImplemented() + + def dump(self, space: str): + raise NotImplemented() + +@jitclass([('point', float32[:]), ('index', int64)]) +class RLeaf:#(RTree): + def __init__(self, point, index): + self.point = point + self.index = index + def contains(self, range) -> bool: + return np.all(range[0] <= self.point) & np.all(range[1] >= self.point) # type: ignore + def members(self, range): + if self.contains(range): + return np.array([self.index]) + return np.empty(0, dtype=np.int_) + def dump(self, space: str): + print(space, f"point {self.point}") + +@jitclass([('points', float32[:, :]), ('indexes', int64[:])]) +class RList:#(RTree): + def __init__(self, points, indexes): + self.points = points + self.indexes = indexes + def contains(self, range) -> bool: + return np.any(np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)) # type: ignore + def members(self, range): + return self.indexes[np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)] + def dump(self, space: str): + print(space, f"points {self.points}") + +node_type = deferred_type() +@jitclass([('d', int64), ('value', float32), ('left', node_type), ('right', node_type), ('width', int64)]) +class RSplit:#(RTree): + def __init__(self, d: int, value: float, left: RTree, right: RTree, width: int): + self.d = d + self.value = value + self.left = left + self.right = right + self.width = width + def contains(self, range) -> bool: + l = self.value - range[0, self.d] # Amount on left side + 
r = range[1, self.d] - self.value # Amount on right side + # Either l or r must be positive, or both + # Pick the biggest first + if l >= r: + if self.left.contains(range): + return True + # Visit the rest if it is inside + if r >= 0: + if self.right.contains(range): + return True + else: + if self.right.contains(range): + return True + # Visit the rest if it is inside + if l >= 0: + if self.left.contains(range): + return True + return False + + def members(self, range): + l = self.value - range[0, self.d] # Amount on left side + r = range[1, self.d] - self.value # Amount on right side + result = None + if l >= 0: + result = self.left.members(range) + if r >= 0: + rights = self.right.members(range) + if result is None: + result = rights + else: + result = np.append(result, rights, axis=0) + return result if result is not None else np.empty(0, dtype=np.int_) + + def size(self) -> int: + return 1 + self.left.size() + self.right.size() + + def depth(self) -> int: + return 1 + max(self.left.depth(), self.right.depth()) + + def dump(self, space: str): + print(space, f"split d{self.d} at {self.value}") + print(space + " <") + self.left.dump(space + "\t") + print(space + " >") + self.right.dump(space + "\t") + +node_type.define(RTree.class_type.instance_type) + +class RWrapper:#(RTree): + def __init__(self, tree, widths): + self.tree = tree + self.widths = widths + def contains(self, point) -> bool: + return self.tree.contains(np.array([point - self.widths, point + self.widths])) + def members(self, point) -> bool: + return self.tree.members(np.array([point - self.widths, point + self.widths])) + def dump(self, space: str): + self.tree.dump(space) + def size(self): + return self.tree.size() + def depth(self): + return self.tree.depth() + +def make_rtree(points): + def make_rtree_internal(points, indexes): + if len(points) == 1: + return RLeaf(points[0], indexes[0]) + if len(points) < 30: + return RList(points, indexes) + # Find split in dimension with most bins + dimensions = points.shape[1] + bins = None + chosen_d_min = 0 + chosen_d_max = 0 + chosen_d = 0 + for d in range(dimensions): + d_max = np.max(points[:, d]) + d_min = np.min(points[:, d]) + d_range = d_max - d_min + d_bins = d_range + if bins == None or d_bins > bins: + bins = d_bins + chosen_d = d + chosen_d_max = d_max + chosen_d_min = d_min + + if bins < 1.3: + # No split is very worthwhile, so dump points + return RList(points, indexes) + + split_at = np.median(points[:, chosen_d]) + # Avoid degenerate cases + if split_at == chosen_d_max or split_at == chosen_d_min: + split_at = (chosen_d_max + chosen_d_min) / 2 + + left_side = points[:, chosen_d] <= split_at + right_side = ~left_side + lefts = points[left_side] + rights = points[right_side] + lefts_indexes = indexes[left_side] + rights_indexes = indexes[right_side] + return RSplit(chosen_d, split_at, make_rtree_internal(lefts, lefts_indexes), make_rtree_internal(rights, rights_indexes), dimensions) + indexes = np.arange(len(points)) + return RWrapper(make_rtree_internal(points, indexes), np.ones(points.shape[1])) + def make_s_set_mask( + m_dist_thresholded: np.ndarray, + k_set_dist_thresholded: np.ndarray, + m_dist_hard: np.ndarray, + k_set_dist_hard: np.ndarray, + starting_positions: np.ndarray, + required: int, + rng: np.random.Generator +): + # Make a k-d tree for m_dist_thresholded + # Ignore dist_hard for now... 
+ m_tree = make_rtree(m_dist_thresholded) + logging.info("Size: %d", m_tree.size()) + logging.info("Depth: %d", m_tree.depth()) + + k_size = k_set_dist_thresholded.shape[0] + m_size = m_dist_thresholded.shape[0] + + s_include = np.zeros(m_size, dtype=np.bool_) + k_miss = np.zeros(k_size, dtype=np.bool_) + + for k in range(k_set_dist_thresholded.shape[0]): + k_row = k_set_dist_thresholded[k] + possible_s = m_tree.members(k_row) + if len(possible_s) == 0: + k_miss[k] = True + else: + samples = min(len(possible_s), required) + chosen_s = rng.choice(possible_s, samples) + if chosen_s.dtype != np.int_: + print(possible_s) + print(chosen_s) + print(chosen_s.dtype) + s_include[chosen_s] = True + + return s_include, k_miss + + +@jit(nopython=True, fastmath=True, error_model="numpy") +def make_s_set_mask_old( m_dist_thresholded: np.ndarray, k_set_dist_thresholded: np.ndarray, m_dist_hard: np.ndarray, @@ -199,27 +414,26 @@ def make_s_set_mask( m_hard = m_dist_hard[m_index] should_include = True - - # check that every element of m_hard matches k_hard - hard_equals = True - for j in range(m_hard.shape[0]): - if m_hard[j] != k_hard[j]: - hard_equals = False - - if not hard_equals: - should_include = False - else: + + if should_include: for j in range(m_row.shape[0]): if abs(m_row[j] - k_row[j]) > 1.0: should_include = False - - if should_include: - s_include[m_index] = True - matches += 1 - - # Don't find any more M's - if matches == required: - break + break + + if should_include: + for j in range(m_hard.shape[0]): + if m_hard[j] != k_hard[j]: + should_include = False + break + + if should_include: + s_include[m_index] = True + matches += 1 + + # Don't find any more M's + if matches == required: + break k_miss[k] = matches == 0 From 095d8e6fa8ea3ebc09f16dbba06e4abfc1b48964 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Fri, 1 Dec 2023 17:29:20 +0000 Subject: [PATCH 05/12] Remove broken numba jitclass --- methods/matching/find_pairs.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 91719db..b63969d 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -187,7 +187,6 @@ def find_match_iteration( logging.info("Finished find match iteration") -@jitclass class RTree: def __init__(): pass @@ -207,8 +206,7 @@ def members(self, range) -> [np.ndarray]: def dump(self, space: str): raise NotImplemented() -@jitclass([('point', float32[:]), ('index', int64)]) -class RLeaf:#(RTree): +class RLeaf(RTree): def __init__(self, point, index): self.point = point self.index = index @@ -221,8 +219,7 @@ def members(self, range): def dump(self, space: str): print(space, f"point {self.point}") -@jitclass([('points', float32[:, :]), ('indexes', int64[:])]) -class RList:#(RTree): +class RList(RTree): def __init__(self, points, indexes): self.points = points self.indexes = indexes @@ -233,9 +230,7 @@ def members(self, range): def dump(self, space: str): print(space, f"points {self.points}") -node_type = deferred_type() -@jitclass([('d', int64), ('value', float32), ('left', node_type), ('right', node_type), ('width', int64)]) -class RSplit:#(RTree): +class RSplit(RTree): def __init__(self, d: int, value: float, left: RTree, right: RTree, width: int): self.d = d self.value = value @@ -290,9 +285,7 @@ def dump(self, space: str): print(space + " >") self.right.dump(space + "\t") -node_type.define(RTree.class_type.instance_type) - -class RWrapper:#(RTree): +class RWrapper(RTree): def __init__(self, 
tree, widths): self.tree = tree self.widths = widths From cbac872330f346cc09f6db23326cddd425d4eac2 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Sat, 2 Dec 2023 17:06:54 +0000 Subject: [PATCH 06/12] WIP: numba-ising the tree search --- methods/matching/find_pairs.py | 106 ++++++++++++++++++++++++++++++--- 1 file changed, 97 insertions(+), 9 deletions(-) diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index b63969d..4f523f9 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -1,9 +1,10 @@ import argparse +import math import os import logging from functools import partial from multiprocessing import Pool, cpu_count, set_start_method -from numba import jit, float32, int64, deferred_type # type: ignore +from numba import jit, float32, int32 # type: ignore from numba.experimental import jitclass import numpy as np @@ -200,6 +201,9 @@ def depth(self) -> int: def size(self) -> int: return 1 + def count(self) -> int: + return 0 + def members(self, range) -> [np.ndarray]: raise NotImplemented() @@ -218,6 +222,8 @@ def members(self, range): return np.empty(0, dtype=np.int_) def dump(self, space: str): print(space, f"point {self.point}") + def count(self): + return 1 class RList(RTree): def __init__(self, points, indexes): @@ -229,6 +235,8 @@ def members(self, range): return self.indexes[np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)] def dump(self, space: str): print(space, f"points {self.points}") + def count(self): + return len(self.points) class RSplit(RTree): def __init__(self, d: int, value: float, left: RTree, right: RTree, width: int): @@ -284,6 +292,87 @@ def dump(self, space: str): self.left.dump(space + "\t") print(space + " >") self.right.dump(space + "\t") + def count(self): + return self.left.count() + self.right.count() + +@jitclass([('ds', int32[:]), ('values', float32[:]), ('items', int32[:]), ('lefts', int32[:]), ('rights', int32[:]), ('rows', float32[:, :]), ('dimensions', int32)]) +class RumbaTree: + def __init__(self, ds, values, items, lefts, rights, rows, dimensions): + self.ds = ds + self.values = values + self.items = items + self.lefts = lefts + self.rights = rights + self.rows = rows + self.dimensions = dimensions + def members(self, range): + queue = [0] + finds = [] + while len(queue) > 0: + pos = queue.pop() + d = self.ds[pos] + value = self.values[pos] + if math.isnan(value): + i = d + item = self.items[i] + while item != -1: + # Check item + found = True + for d in range(self.dimensions): + value = self.rows[item, d] + if value < range[0, d]: + found = False + break + if value > range[1, d]: + found = False + break + if found: + finds.append(item) + i += 1 + item = self.items[i] + else: + if value >= range[0, d]: + queue.append(self.lefts[pos]) + if value <= range[1, d]: + queue.append(self.rights[pos]) + return finds + +NAN = float('nan') +def make_rumba_tree(tree, rows): + ds = [] + values = [] + items = [] + lefts = [] + rights = [] + def recurse(node): + if isinstance(node, RSplit): + ds.append(node.d) + values.append(node.value) + lefts.append(len(ds)) + recurse(node.left) + rights.append(len(ds)) + recurse(node.right) + elif isinstance(node, RList): + values.append(NAN) + ds.append(len(items)) + for item in node.indexes: + items.append(item) + items.append(-1) + elif isinstance(node, RLeaf): + values.append(NAN) + ds.append(len(items)) + items.append(node.index) + items.append(-1) + recurse(tree) + return RumbaTree( + np.array(ds, dtype=np.int32), + 
np.array(values, dtype=np.float32), + np.array(items, dtype=np.int32), + np.array(lefts, dtype=np.int32), + np.array(rights, dtype=np.int32), + rows, + rows.shape[1], + ) class RWrapper(RTree): def __init__(self, tree, widths): @@ -299,6 +388,8 @@ def size(self): return self.tree.size() def depth(self): return self.tree.depth() + def count(self): + return self.tree.count() def make_rtree(points): def make_rtree_internal(points, indexes): @@ -340,7 +431,7 @@ def make_rtree_internal(points, indexes): rights_indexes = indexes[right_side] return RSplit(chosen_d, split_at, make_rtree_internal(lefts, lefts_indexes), make_rtree_internal(rights, rights_indexes), dimensions) indexes = np.arange(len(points)) - return RWrapper(make_rtree_internal(points, indexes), np.ones(points.shape[1])) + return RWrapper(make_rumba_tree(make_rtree_internal(points, indexes), points), np.ones(points.shape[1])) def make_s_set_mask( m_dist_thresholded: np.ndarray, @@ -354,8 +445,9 @@ def make_s_set_mask( # Make a k-d tree for m_dist_thresholded # Ignore dist_hard for now... m_tree = make_rtree(m_dist_thresholded) - logging.info("Size: %d", m_tree.size()) - logging.info("Depth: %d", m_tree.depth()) + #logging.info("Size: %d", m_tree.size()) + #logging.info("Depth: %d", m_tree.depth()) + #logging.info("Points: %d Len: %d", m_tree.count(), len(m_dist_thresholded)) k_size = k_set_dist_thresholded.shape[0] m_size = m_dist_thresholded.shape[0] @@ -370,11 +462,7 @@ def make_s_set_mask( k_miss[k] = True else: samples = min(len(possible_s), required) - chosen_s = rng.choice(possible_s, samples) - if chosen_s.dtype != np.int_: - print(possible_s) - print(chosen_s) - print(chosen_s.dtype) + chosen_s = rng.choice(possible_s, samples, replace=False) s_include[chosen_s] = True return s_include, k_miss From 1d71a0c4c9c0aeff4d5759b9f318221c96f97329 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Mon, 4 Dec 2023 14:55:35 +0000 Subject: [PATCH 07/12] WIP: working rumba tree --- data | 1 + inputs | 1 + methods/matching/find_pairs.py | 49 +++++++++++++++++++++++----------- secrets/rhm31.cap | 1 + 4 files changed, 36 insertions(+), 16 deletions(-) create mode 120000 data create mode 120000 inputs create mode 100644 secrets/rhm31.cap diff --git a/data b/data new file mode 120000 index 0000000..f453638 --- /dev/null +++ b/data @@ -0,0 +1 @@ +../testing/calculate_k/data \ No newline at end of file diff --git a/inputs b/inputs new file mode 120000 index 0000000..11101dd --- /dev/null +++ b/inputs @@ -0,0 +1 @@ +../testing/calculate_k/inputs \ No newline at end of file diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 4f523f9..fd7564b 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -267,12 +267,15 @@ def contains(self, range) -> bool: return False def members(self, range): + #print(f"Node d:{self.d} value:{self.value}") l = self.value - range[0, self.d] # Amount on left side r = range[1, self.d] - self.value # Amount on right side result = None if l >= 0: + #print(f" <-") result = self.left.members(range) if r >= 0: + #print(f" ->") rights = self.right.members(range) if result is None: result = rights @@ -297,7 +300,7 @@ def count(self): @jitclass([('ds', int32[:]), ('values', float32[:]), ('items', int32[:]), ('lefts', int32[:]), ('rights', int32[:]), ('rows', float32[:, :]), ('dimensions', int32)]) class RumbaTree: - def __init__(self, ds, values, items, lefts, rights, rows, dimensions): + def __init__(self, ds: np.ndarray, values: np.ndarray, items: np.ndarray, lefts: 
np.ndarray, rights: np.ndarray, rows: np.ndarray, dimensions: int): self.ds = ds self.values = values self.items = items @@ -305,7 +308,9 @@ def __init__(self, ds, values, items, lefts, rights, rows, dimensions): self.rights = rights self.rows = rows self.dimensions = dimensions - def members(self, range): + def members(self, point: np.ndarray): + low = point - 1 + high = point + 1 queue = [0] finds = [] while len(queue) > 0: @@ -313,6 +318,7 @@ def members(self, range): d = self.ds[pos] value = self.values[pos] if math.isnan(value): + #print(f"Pos {pos} search_range: {search_range}") i = d item = self.items[i] while item != -1: @@ -320,25 +326,30 @@ def members(self, range): found = True for d in range(self.dimensions): value = self.rows[item, d] - if value < range[0, d]: + if value < low[d]: found = False break - if value > range[1, d]: + if value > high[d]: found = False break + #print(f"search_range: {search_range}") + #print(f"Item {item} found:{found} row:{self.rows[item]}") if found: finds.append(item) i += 1 item = self.items[i] else: - if value >= range[0, d]: - queue.append(self.lefts[pos]) - if value <= range[1, d]: + #print(f"Pos {pos} d:{d} value:{value}") + if value <= high[d]: + #print(f" -> {self.rights[pos]}") queue.append(self.rights[pos]) + if value >= low[d]: + #print(f" <- {self.lefts[pos]}") + queue.append(self.lefts[pos]) return finds NAN = float('nan') -def make_rumba_tree(tree, rows): +def make_rumba_tree(tree: RTree, rows: np.ndarray): ds = [] values = [] items = [] @@ -346,21 +357,27 @@ def make_rumba_tree(tree, rows): rights = [] def recurse(node): if isinstance(node, RSplit): + pos = len(ds) ds.append(node.d) values.append(node.value) - lefts.append(len(ds)) + lefts.append(pos + 1) # Next node we output will be left + rights.append(0xDEADBEEF) # Put placeholder here recurse(node.left) - rights.append(len(ds)) + rights[pos] = len(ds) # Fixup right to be the next node we output recurse(node.right) elif isinstance(node, RList): values.append(NAN) ds.append(len(items)) + lefts.append(-1) # Specific wrong values for spotting incorrect tree build + rights.append(-2) for item in node.indexes: items.append(item) items.append(-1) elif isinstance(node, RLeaf): values.append(NAN) ds.append(len(items)) + lefts.append(-3) + rights.append(-4) items.append(node.index) items.append(-1) recurse(tree) @@ -431,7 +448,7 @@ def make_rtree_internal(points, indexes): rights_indexes = indexes[right_side] return RSplit(chosen_d, split_at, make_rtree_internal(lefts, lefts_indexes), make_rtree_internal(rights, rights_indexes), dimensions) indexes = np.arange(len(points)) - return RWrapper(make_rumba_tree(make_rtree_internal(points, indexes), points), np.ones(points.shape[1])) + return RWrapper(make_rtree_internal(points, indexes), np.ones(points.shape[1])) def make_s_set_mask( m_dist_thresholded: np.ndarray, @@ -445,9 +462,7 @@ def make_s_set_mask( # Make a k-d tree for m_dist_thresholded # Ignore dist_hard for now... 
m_tree = make_rtree(m_dist_thresholded) - #logging.info("Size: %d", m_tree.size()) - #logging.info("Depth: %d", m_tree.depth()) - #logging.info("Points: %d Len: %d", m_tree.count(), len(m_dist_thresholded)) + rumba_tree = make_rumba_tree(m_tree.tree, m_dist_thresholded) k_size = k_set_dist_thresholded.shape[0] m_size = m_dist_thresholded.shape[0] @@ -455,15 +470,17 @@ def make_s_set_mask( s_include = np.zeros(m_size, dtype=np.bool_) k_miss = np.zeros(k_size, dtype=np.bool_) - for k in range(k_set_dist_thresholded.shape[0]): + for k in range(k_size): k_row = k_set_dist_thresholded[k] - possible_s = m_tree.members(k_row) + possible_s = rumba_tree.members(k_row) if len(possible_s) == 0: k_miss[k] = True + #logging.info("MISS %d of %d", k, k_size) else: samples = min(len(possible_s), required) chosen_s = rng.choice(possible_s, samples, replace=False) s_include[chosen_s] = True + #logging.info("%d of %d found %d, picked %a", k, k_size, len(possible_s), chosen_s) return s_include, k_miss diff --git a/secrets/rhm31.cap b/secrets/rhm31.cap new file mode 100644 index 0000000..b86ac15 --- /dev/null +++ b/secrets/rhm31.cap @@ -0,0 +1 @@ +capnp://sha-256:SgebTxjjj9SY29G4_I-cjSNyOBmHYq2LDYwNGb7JXWE@/tmp/ocurrent.sock/ITvMU-BpgDAdGZ63klCBtJ_dHYI \ No newline at end of file From 1751fd1a9e5813563874aa7e2248d07b7a0df4f8 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Wed, 6 Dec 2023 16:30:09 +0000 Subject: [PATCH 08/12] Pull KDTree and RumbaTree into utils and add test --- methods/matching/find_pairs.py | 273 +-------------------------------- methods/utils/kd_tree.py | 268 ++++++++++++++++++++++++++++++++ test/test_kd_tree.py | 152 ++++++++++++++++++ 3 files changed, 424 insertions(+), 269 deletions(-) create mode 100644 methods/utils/kd_tree.py create mode 100644 test/test_kd_tree.py diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index fd7564b..cb5369d 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -1,16 +1,15 @@ import argparse -import math import os import logging from functools import partial from multiprocessing import Pool, cpu_count, set_start_method -from numba import jit, float32, int32 # type: ignore -from numba.experimental import jitclass import numpy as np import pandas as pd +from numba import jit from methods.common.luc import luc_matching_columns +from methods.utils.kd_tree import make_kdrangetree, make_rumba_tree REPEAT_MATCH_FINDING = 1 DEFAULT_DISTANCE = 10000000.0 @@ -188,268 +187,6 @@ def find_match_iteration( logging.info("Finished find match iteration") -class RTree: - def __init__(): - pass - - def contains(self, range) -> bool: - raise NotImplemented() - - def depth(self) -> int: - return 1 - - def size(self) -> int: - return 1 - - def count(self) -> int: - return 0 - - def members(self, range) -> [np.ndarray]: - raise NotImplemented() - - def dump(self, space: str): - raise NotImplemented() - -class RLeaf(RTree): - def __init__(self, point, index): - self.point = point - self.index = index - def contains(self, range) -> bool: - return np.all(range[0] <= self.point) & np.all(range[1] >= self.point) # type: ignore - def members(self, range): - if self.contains(range): - return np.array([self.index]) - return np.empty(0, dtype=np.int_) - def dump(self, space: str): - print(space, f"point {self.point}") - def count(self): - return 1 - -class RList(RTree): - def __init__(self, points, indexes): - self.points = points - self.indexes = indexes - def contains(self, range) -> bool: - return np.any(np.all(range[0] <= 
self.points, axis=1) & np.all(range[1] >= self.points, axis=1)) # type: ignore - def members(self, range): - return self.indexes[np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)] - def dump(self, space: str): - print(space, f"points {self.points}") - def count(self): - return len(self.points) - -class RSplit(RTree): - def __init__(self, d: int, value: float, left: RTree, right: RTree, width: int): - self.d = d - self.value = value - self.left = left - self.right = right - self.width = width - def contains(self, range) -> bool: - l = self.value - range[0, self.d] # Amount on left side - r = range[1, self.d] - self.value # Amount on right side - # Either l or r must be positive, or both - # Pick the biggest first - if l >= r: - if self.left.contains(range): - return True - # Visit the rest if it is inside - if r >= 0: - if self.right.contains(range): - return True - else: - if self.right.contains(range): - return True - # Visit the rest if it is inside - if l >= 0: - if self.left.contains(range): - return True - return False - - def members(self, range): - #print(f"Node d:{self.d} value:{self.value}") - l = self.value - range[0, self.d] # Amount on left side - r = range[1, self.d] - self.value # Amount on right side - result = None - if l >= 0: - #print(f" <-") - result = self.left.members(range) - if r >= 0: - #print(f" ->") - rights = self.right.members(range) - if result is None: - result = rights - else: - result = np.append(result, rights, axis=0) - return result if result is not None else np.empty(0, dtype=np.int_) - - def size(self) -> int: - return 1 + self.left.size() + self.right.size() - - def depth(self) -> int: - return 1 + max(self.left.depth(), self.right.depth()) - - def dump(self, space: str): - print(space, f"split d{self.d} at {self.value}") - print(space + " <") - self.left.dump(space + "\t") - print(space + " >") - self.right.dump(space + "\t") - def count(self): - return self.left.count() + self.right.count() - -@jitclass([('ds', int32[:]), ('values', float32[:]), ('items', int32[:]), ('lefts', int32[:]), ('rights', int32[:]), ('rows', float32[:, :]), ('dimensions', int32)]) -class RumbaTree: - def __init__(self, ds: np.ndarray, values: np.ndarray, items: np.ndarray, lefts: np.ndarray, rights: np.ndarray, rows: np.ndarray, dimensions: int): - self.ds = ds - self.values = values - self.items = items - self.lefts = lefts - self.rights = rights - self.rows = rows - self.dimensions = dimensions - def members(self, point: np.ndarray): - low = point - 1 - high = point + 1 - queue = [0] - finds = [] - while len(queue) > 0: - pos = queue.pop() - d = self.ds[pos] - value = self.values[pos] - if math.isnan(value): - #print(f"Pos {pos} search_range: {search_range}") - i = d - item = self.items[i] - while item != -1: - # Check item - found = True - for d in range(self.dimensions): - value = self.rows[item, d] - if value < low[d]: - found = False - break - if value > high[d]: - found = False - break - #print(f"search_range: {search_range}") - #print(f"Item {item} found:{found} row:{self.rows[item]}") - if found: - finds.append(item) - i += 1 - item = self.items[i] - else: - #print(f"Pos {pos} d:{d} value:{value}") - if value <= high[d]: - #print(f" -> {self.rights[pos]}") - queue.append(self.rights[pos]) - if value >= low[d]: - #print(f" <- {self.lefts[pos]}") - queue.append(self.lefts[pos]) - return finds - -NAN = float('nan') -def make_rumba_tree(tree: RTree, rows: np.ndarray): - ds = [] - values = [] - items = [] - lefts = [] - rights = [] - def 
recurse(node): - if isinstance(node, RSplit): - pos = len(ds) - ds.append(node.d) - values.append(node.value) - lefts.append(pos + 1) # Next node we output will be left - rights.append(0xDEADBEEF) # Put placeholder here - recurse(node.left) - rights[pos] = len(ds) # Fixup right to be the next node we output - recurse(node.right) - elif isinstance(node, RList): - values.append(NAN) - ds.append(len(items)) - lefts.append(-1) # Specific wrong values for spotting incorrect tree build - rights.append(-2) - for item in node.indexes: - items.append(item) - items.append(-1) - elif isinstance(node, RLeaf): - values.append(NAN) - ds.append(len(items)) - lefts.append(-3) - rights.append(-4) - items.append(node.index) - items.append(-1) - recurse(tree) - return RumbaTree( - np.array(ds, dtype=np.int32), - np.array(values, dtype=np.float32), - np.array(items, dtype=np.int32), - np.array(lefts, dtype=np.int32), - np.array(rights, dtype=np.int32), - rows, - rows.shape[1], - ) - -class RWrapper(RTree): - def __init__(self, tree, widths): - self.tree = tree - self.widths = widths - def contains(self, point) -> bool: - return self.tree.contains(np.array([point - self.widths, point + self.widths])) - def members(self, point) -> bool: - return self.tree.members(np.array([point - self.widths, point + self.widths])) - def dump(self, space: str): - self.tree.dump(space) - def size(self): - return self.tree.size() - def depth(self): - return self.tree.depth() - def count(self): - return self.tree.count() - -def make_rtree(points): - def make_rtree_internal(points, indexes): - if len(points) == 1: - return RLeaf(points[0], indexes[0]) - if len(points) < 30: - return RList(points, indexes) - # Find split in dimension with most bins - dimensions = points.shape[1] - bins = None - chosen_d_min = 0 - chosen_d_max = 0 - chosen_d = 0 - for d in range(dimensions): - d_max = np.max(points[:, d]) - d_min = np.min(points[:, d]) - d_range = d_max - d_min - d_bins = d_range - if bins == None or d_bins > bins: - bins = d_bins - chosen_d = d - chosen_d_max = d_max - chosen_d_min = d_min - - if bins < 1.3: - # No split is very worthwhile, so dump points - return RList(points, indexes) - - split_at = np.median(points[:, chosen_d]) - # Avoid degenerate cases - if split_at == chosen_d_max or split_at == chosen_d_min: - split_at = (chosen_d_max + chosen_d_min) / 2 - - left_side = points[:, chosen_d] <= split_at - right_side = ~left_side - lefts = points[left_side] - rights = points[right_side] - lefts_indexes = indexes[left_side] - rights_indexes = indexes[right_side] - return RSplit(chosen_d, split_at, make_rtree_internal(lefts, lefts_indexes), make_rtree_internal(rights, rights_indexes), dimensions) - indexes = np.arange(len(points)) - return RWrapper(make_rtree_internal(points, indexes), np.ones(points.shape[1])) - def make_s_set_mask( m_dist_thresholded: np.ndarray, k_set_dist_thresholded: np.ndarray, @@ -461,8 +198,8 @@ def make_s_set_mask( ): # Make a k-d tree for m_dist_thresholded # Ignore dist_hard for now... 
- m_tree = make_rtree(m_dist_thresholded) - rumba_tree = make_rumba_tree(m_tree.tree, m_dist_thresholded) + m_tree = make_kdrangetree(m_dist_thresholded, np.ones(m_dist_thresholded.shape[1])) + rumba_tree = make_rumba_tree(m_tree, m_dist_thresholded) k_size = k_set_dist_thresholded.shape[0] m_size = m_dist_thresholded.shape[0] @@ -475,12 +212,10 @@ def make_s_set_mask( possible_s = rumba_tree.members(k_row) if len(possible_s) == 0: k_miss[k] = True - #logging.info("MISS %d of %d", k, k_size) else: samples = min(len(possible_s), required) chosen_s = rng.choice(possible_s, samples, replace=False) s_include[chosen_s] = True - #logging.info("%d of %d found %d, picked %a", k, k_size, len(possible_s), chosen_s) return s_include, k_miss diff --git a/methods/utils/kd_tree.py b/methods/utils/kd_tree.py new file mode 100644 index 0000000..8d459dc --- /dev/null +++ b/methods/utils/kd_tree.py @@ -0,0 +1,268 @@ +import math +from typing import List + +import numpy as np +from numba import float32, int32 # type: ignore +from numba.experimental import jitclass + + +class KDTree: + def __init__(self): + pass + + def contains(self, _range) -> bool: + raise NotImplemented() + + def depth(self) -> int: + return 1 + + def size(self) -> int: + return 1 + + def count(self) -> int: + return 0 + + def members(self, _range) -> np.ndarray: + raise NotImplemented() + + def dump(self, _space: str): + raise NotImplemented() + +class KDLeaf(KDTree): + def __init__(self, point, index): + self.point = point + self.index = index + def contains(self, range) -> bool: + return np.all(range[0] <= self.point) & np.all(range[1] >= self.point) # type: ignore + def members(self, range): + if self.contains(range): + return np.array([self.index]) + return np.empty(0, dtype=np.int_) + def dump(self, space: str): + print(space, f"point {self.point}") + def count(self): + return 1 + +class KDList(KDTree): + def __init__(self, points, indexes): + self.points = points + self.indexes = indexes + def contains(self, range) -> bool: + return np.any(np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)) # type: ignore + def members(self, range): + return self.indexes[np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)] + def dump(self, space: str): + print(space, f"points {self.points}") + def count(self): + return len(self.points) + +class KDSplit(KDTree): + def __init__(self, d: int, value: float, left: KDTree, right: KDTree): + self.d = d + self.value = value + self.left = left + self.right = right + def contains(self, range) -> bool: + l = self.value - range[0, self.d] # Amount on left side + r = range[1, self.d] - self.value # Amount on right side + # Either l or r must be positive, or both + # Pick the biggest first + if l >= r: + if self.left.contains(range): + return True + # Visit the rest if it is inside + if r >= 0: + if self.right.contains(range): + return True + else: + if self.right.contains(range): + return True + # Visit the rest if it is inside + if l >= 0: + if self.left.contains(range): + return True + return False + + def members(self, range): + result = None + if self.value >= range[0, self.d]: + result = self.left.members(range) + if range[1, self.d] >= self.value: + rights = self.right.members(range) + if result is None: + result = rights + else: + result = np.append(result, rights, axis=0) + return result if result is not None else np.empty(0, dtype=np.int_) + + def size(self) -> int: + return 1 + self.left.size() + self.right.size() + + def depth(self) -> int: 
+ return 1 + max(self.left.depth(), self.right.depth()) + + def dump(self, space: str): + print(space, f"split d{self.d} at {self.value}") + print(space + " <") + self.left.dump(space + "\t") + print(space + " >") + self.right.dump(space + "\t") + def count(self): + return self.left.count() + self.right.count() + + +class KDRangeTree: + def __init__(self, tree, widths): + self.tree = tree + self.widths = widths + def contains(self, point) -> bool: + return self.tree.contains(np.array([point - self.widths, point + self.widths])) + def members(self, point) -> np.ndarray: + return self.tree.members(np.array([point - self.widths, point + self.widths])) + def dump(self, space: str): + self.tree.dump(space) + def size(self): + return self.tree.size() + def depth(self): + return self.tree.depth() + def count(self): + return self.tree.count() + +def make_kdrangetree(points, widths): + def make_kdtree_internal(points, indexes): + if len(points) == 1: + return KDLeaf(points[0], indexes[0]) + if len(points) < 30: + return KDList(points, indexes) + # Find split in dimension with most bins + dimensions = points.shape[1] + bins: float = None # type: ignore + chosen_d_min = 0 + chosen_d_max = 0 + chosen_d = 0 + for d in range(dimensions): + d_max = np.max(points[:, d]) + d_min = np.min(points[:, d]) + d_range = d_max - d_min + d_bins = d_range / widths[d] + if bins is None or d_bins > bins: + bins = d_bins + chosen_d = d + chosen_d_max = d_max + chosen_d_min = d_min + + if bins < 1.3: + # No split is very worthwhile, so dump points + return KDList(points, indexes) + + split_at = np.median(points[:, chosen_d]) + # Avoid degenerate cases + if split_at == chosen_d_max or split_at == chosen_d_min: + split_at = (chosen_d_max + chosen_d_min) / 2 + + left_side = points[:, chosen_d] <= split_at + right_side = ~left_side + lefts = points[left_side] + rights = points[right_side] + lefts_indexes = indexes[left_side] + rights_indexes = indexes[right_side] + return KDSplit(chosen_d, split_at, make_kdtree_internal(lefts, lefts_indexes), make_kdtree_internal(rights, rights_indexes)) + indexes = np.arange(len(points)) + return KDRangeTree(make_kdtree_internal(points, indexes), widths) + +@jitclass([('ds', int32[:]), ('values', float32[:]), ('items', int32[:]), ('lefts', int32[:]), ('rights', int32[:]), ('rows', float32[:, :]), ('dimensions', int32), ('widths', float32[:])]) +class RumbaTree: + def __init__(self, ds: np.ndarray, values: np.ndarray, items: np.ndarray, lefts: np.ndarray, rights: np.ndarray, rows: np.ndarray, dimensions: int, widths: np.ndarray): + self.ds = ds + self.values = values + self.items = items + self.lefts = lefts + self.rights = rights + self.rows = rows + self.dimensions = dimensions + self.widths = widths + def members(self, point: np.ndarray): + low = point - self.widths + high = point + self.widths + queue = [0] + finds = [] + while len(queue) > 0: + pos = queue.pop() + d = self.ds[pos] + value = self.values[pos] + if math.isnan(value): + i = d + item = self.items[i] + while item != -1: + # Check item + found = True + for d in range(self.dimensions): + value = self.rows[item, d] + if value < low[d]: + found = False + break + if value > high[d]: + found = False + break + if found: + finds.append(item) + i += 1 + item = self.items[i] + else: + if value <= high[d]: + queue.append(self.rights[pos]) + if value >= low[d]: + queue.append(self.lefts[pos]) + return finds + +NAN = float('nan') +def make_rumba_tree(tree: KDRangeTree, rows: np.ndarray): + ds = [] + values = [] + items = [] + lefts = [] 
+ rights = [] + widths = None + def recurse(node): + nonlocal widths + if isinstance(node, KDSplit): + pos = len(ds) + ds.append(node.d) + values.append(node.value) + lefts.append(pos + 1) # Next node we output will be left + rights.append(0xDEADBEEF) # Put placeholder here... + recurse(node.left) + rights[pos] = len(ds) # ..and fixup right to be the next node we output + recurse(node.right) + elif isinstance(node, KDList): + values.append(NAN) + ds.append(len(items)) + lefts.append(-1) # Specific invalid values for debugging an errors in tree build + rights.append(-2) + for item in node.indexes: + items.append(item) + items.append(-1) + elif isinstance(node, KDLeaf): + values.append(NAN) + ds.append(len(items)) + lefts.append(-3) + rights.append(-4) + items.append(node.index) + items.append(-1) + elif isinstance(node, KDRangeTree): + widths = node.widths + recurse(node.tree) + recurse(tree) + if widths is None: + raise ValueError(f"Expected KDRangeTree, got {tree}") + return RumbaTree( + np.array(ds, dtype=np.int32), + np.array(values, dtype=np.float32), + np.array(items, dtype=np.int32), + np.array(lefts, dtype=np.int32), + np.array(rights, dtype=np.int32), + np.ascontiguousarray(rows, dtype=np.float32), + rows.shape[1], + np.ascontiguousarray(widths, dtype=np.float32), + ) + diff --git a/test/test_kd_tree.py b/test/test_kd_tree.py new file mode 100644 index 0000000..5ad6955 --- /dev/null +++ b/test/test_kd_tree.py @@ -0,0 +1,152 @@ +import math +from time import time +import numpy as np +import pandas as pd + +from methods.common.luc import luc_matching_columns +from methods.utils.kd_tree import KDRangeTree, KDTree, make_kdrangetree, make_rumba_tree + +ALLOWED_VARIATION = np.array([ + 200, + 2.5, + 10, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, +]) + +def test_kd_tree_matches_as_expected(): + def build_rects(items): + rects = [] + for item in items: + lefts = [] + rights = [] + for dimension, value in enumerate(item): + width = ALLOWED_VARIATION[dimension] + if width < 0: + fraction = -width + width = value * fraction + lefts.append(value - width) + rights.append(value + width) + rects.append([lefts, rights]) + return np.array(rects) + + expected_fraction = 1 / 100 # This proportion of pixels we end up matching + + def build_kdranged_tree_for_k(k_rows) -> KDRangeTree: + return make_kdrangetree(np.array([( + row.elevation, + row.slope, + row.access, + row["cpc0_u"], + row["cpc0_d"], + row["cpc5_u"], + row["cpc5_d"], + row["cpc10_u"], + row["cpc10_d"], + ) for row in k_rows + ]), ALLOWED_VARIATION) + + + luc0, luc5, luc10 = luc_matching_columns(2012) + source_pixels = pd.read_parquet("./test/data/1201-k.parquet") + + # Split source_pixels into classes + source_rows = [] + for _, row in source_pixels.iterrows(): + key = (int(row.ecoregion) << 16) | (int(row[luc0]) << 10) | (int(row[luc5]) << 5) | (int(row[luc10])) + if key != 1967137: + continue + source_rows.append(row) + + source = np.array([ + [ + row.elevation, + row.slope, + row.access, + row["cpc0_u"], + row["cpc0_d"], + row["cpc5_u"], + row["cpc5_d"], + row["cpc10_u"], + row["cpc10_d"], + ] for row in source_rows + ]) + + # Invent an array of values that matches the expected_fraction + length = 10000 + np.random.seed(42) + + ranges = np.transpose(np.array([ + np.min(source, axis=0), + np.max(source, axis=0) + ])) + + # Safe ranges (exclude 10% of outliers) + safe_ranges = np.transpose(np.array([ + np.quantile(source, 0.05, axis=0), + np.quantile(source, 0.95, axis=0) + ])) + + # Need to put an estimate here of how much of the area 
From d51ebdea2f6503a679c53f8807ac45b4bb192012 Mon Sep 17 00:00:00 2001
From: Robin Message
Date: Wed, 6 Dec 2023 17:15:59 +0000
Subject: [PATCH 09/12] Finished faster find_pairs

---
 methods/matching/find_pairs.py | 117 ++++++++++-----------------------
 1 file changed, 35 insertions(+), 82 deletions(-)

diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py
index cb5369d..e387c7c 100644
--- a/methods/matching/find_pairs.py
+++ b/methods/matching/find_pairs.py
@@ -11,7 +11,7 @@
 from methods.common.luc import luc_matching_columns
 from methods.utils.kd_tree import make_kdrangetree, make_rumba_tree
 
-REPEAT_MATCH_FINDING = 1
+REPEAT_MATCH_FINDING = 100
 
 DEFAULT_DISTANCE = 10000000.0
 DEBUG = False
@@ -38,11 +38,6 @@ def find_match_iteration(
     logging.info("Loading K from %s", k_parquet_filename)
     k_set = pd.read_parquet(k_parquet_filename)
-    # Methodology 6.5.7: For a 10% sample of K
-    k_subset = k_set.sample(
-        frac=0.1,
-        random_state=rng
-    ).reset_index()
 
     logging.info("Loading M from %s", m_parquet_filename)
     m_set = pd.read_parquet(m_parquet_filename)
@@ -64,8 +59,7 @@ def find_match_iteration(
     m_dist_thresholded_df = m_set[DISTANCE_COLUMNS] / thresholds_for_columns
     k_set_dist_thresholded_df = k_set[DISTANCE_COLUMNS] / thresholds_for_columns
 
-    # IDEA: Maybe we can bin these somehow?
-
+    # Rearrange columns by variance so we throw out the least likely to match first
+    # except the bottom three which are deforestation CPCs and have more cross-variance between K and M
     variances = np.std(m_dist_thresholded_df, axis=0)
@@ -89,33 +83,50 @@ def find_match_iteration(
     hard_match_columns = ['country', 'ecoregion', luc10, luc5, luc0]
     assert len(hard_match_columns) == HARD_COLUMN_COUNT
 
-    # similar to the above, make the hard match columns contiguous float32 numpy arrays
-    m_dist_hard = np.ascontiguousarray(m_set[hard_match_columns].to_numpy()).astype(np.int32)
-    k_set_dist_hard = np.ascontiguousarray(k_set[hard_match_columns].to_numpy()).astype(np.int32)
+    # Find categories in K
+    hard_match_categories = [k[hard_match_columns].to_numpy() for _, k in k_set.iterrows()]
+    hard_match_categories = {k.tobytes(): k for k in hard_match_categories}
+    no_potentials = []
 
     # Methodology 6.5.5: S should be 10 times the size of K
     required = 10
     logging.info("Running make_s_set_mask... required: %d", required)
-    starting_positions = rng.integers(0, int(m_dist_thresholded.shape[0]), int(k_set_dist_thresholded.shape[0]))
-    s_set_mask_true, no_potentials = make_s_set_mask(
-        m_dist_thresholded,
-        k_set_dist_thresholded,
-        m_dist_hard,
-        k_set_dist_hard,
-        starting_positions,
-        required,
-        rng
-    )
+
+    s_set_mask_true = np.zeros(m_set.shape[0], dtype=np.bool_)
+    no_potentials = np.zeros(k_set.shape[0], dtype=np.bool_)
+
+    # Split K and M into those categories and create masks
+    for values in hard_match_categories.values():
+        k_selector = np.all(k_set[hard_match_columns] == values, axis=1)
+        m_selector = np.all(m_set[hard_match_columns] == values, axis=1)
+        logging.info(" category: %a |K|: %d |M|: %d", values, k_selector.sum(), m_selector.sum())
+        # Make masks for each of those pairs
+        key_s_set_mask_true, key_no_potentials = make_s_set_mask(
+            m_dist_thresholded[m_selector],
+            k_set_dist_thresholded[k_selector],
+            required,
+            rng
+        )
+        # Merge into one s_set_mask_true
+        s_set_mask_true[m_selector] = key_s_set_mask_true
+        # Merge into no_potentials
+        no_potentials[k_selector] = key_no_potentials
 
     logging.info("Done make_s_set_mask. s_set_mask.shape: %a", {s_set_mask_true.shape})
 
     s_set = m_set[s_set_mask_true]
+    logging.info("Finished preparing s_set. shape: %a", {s_set.shape})
 
     potentials = np.invert(no_potentials)
-    # FIXME: Not sure this line is meaningful any more if potentials drawn from K?
-    k_subset = k_subset[potentials]
-    logging.info("Finished preparing s_set. shape: %a", {s_set.shape})
+    # Methodology 6.5.7: For a 10% sample of K
+    k_subset = k_set.sample(
+        frac=0.1,
+        random_state=rng
+    )
+    k_subset = k_subset.apply(lambda row: potentials[row.index])
+    k_subset.reset_index()
+    logging.info("Finished preparing k_subset. shape: %a", {k_subset.shape})
 
     # Notes:
     # 1. Not all pixels may have matches
@@ -190,14 +201,9 @@ def make_s_set_mask(
     m_dist_thresholded: np.ndarray,
     k_set_dist_thresholded: np.ndarray,
-    m_dist_hard: np.ndarray,
-    k_set_dist_hard: np.ndarray,
-    starting_positions: np.ndarray,
     required: int,
     rng: np.random.Generator
 ):
-    # Make a k-d tree for m_dist_thresholded
-    # Ignore dist_hard for now...
     m_tree = make_kdrangetree(m_dist_thresholded, np.ones(m_dist_thresholded.shape[1]))
     rumba_tree = make_rumba_tree(m_tree, m_dist_thresholded)
 
     k_size = k_set_dist_thresholded.shape[0]
     m_size = m_dist_thresholded.shape[0]
 
     s_include = np.zeros(m_size, dtype=np.bool_)
     k_miss = np.zeros(k_size, dtype=np.bool_)
 
     for k in range(k_size):
         k_row = k_set_dist_thresholded[k]
         possible_s = rumba_tree.members(k_row)
         if len(possible_s) == 0:
             k_miss[k] = True
         else:
             samples = min(len(possible_s), required)
             chosen_s = rng.choice(possible_s, samples, replace=False)
             s_include[chosen_s] = True
-
-    return s_include, k_miss
-
-
-@jit(nopython=True, fastmath=True, error_model="numpy")
-def make_s_set_mask_old(
-    m_dist_thresholded: np.ndarray,
-    k_set_dist_thresholded: np.ndarray,
-    m_dist_hard: np.ndarray,
-    k_set_dist_hard: np.ndarray,
-    starting_positions: np.ndarray,
-    required: int
-):
-    k_size = k_set_dist_thresholded.shape[0]
-    m_size = m_dist_thresholded.shape[0]
-
-    s_include = np.zeros(m_size, dtype=np.bool_)
-    k_miss = np.zeros(k_size, dtype=np.bool_)
-
-    for k in range(k_size):
-        matches = 0
-        k_row = k_set_dist_thresholded[k, :]
-        k_hard = k_set_dist_hard[k]
-
-        for index in range(m_size):
-            m_index = (index + starting_positions[k]) % m_size
-
-            m_row = m_dist_thresholded[m_index, :]
-            m_hard = m_dist_hard[m_index]
-
-            should_include = True
-
-            if should_include:
-                for j in range(m_row.shape[0]):
-                    if abs(m_row[j] - k_row[j]) > 1.0:
-                        should_include = False
-                        break
-
-            if should_include:
-                for j in range(m_hard.shape[0]):
-                    if m_hard[j] != k_hard[j]:
-                        should_include = False
-                        break
-
-            if should_include:
-                s_include[m_index] = True
-                matches += 1
-
-                # Don't find any more M's
-                if matches == required:
-                    break
-
-        k_miss[k] = matches == 0
 
     return s_include, k_miss

From 33b78b04eb600e1ed7dbc5e26fd727d4cc247c4b Mon Sep 17 00:00:00 2001
From: Robin Message
Date: Thu, 7 Dec 2023 12:20:01 +0000
Subject: [PATCH 10/12] Removed accidental tree additions

---
 data   | 1 -
 inputs | 1 -
 2 files changed, 2 deletions(-)
 delete mode 120000 data
 delete mode 120000 inputs

diff --git a/data b/data
deleted file mode 120000
index f453638..0000000
--- a/data
+++ /dev/null
@@ -1 +0,0 @@
-../testing/calculate_k/data
\ No newline at end of file
diff --git a/inputs b/inputs
deleted file mode 120000
index 11101dd..0000000
--- a/inputs
+++ /dev/null
@@ -1 +0,0 @@
-../testing/calculate_k/inputs
\ No newline at end of file
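The next patch folds sampling into the tree walk itself: keep the first `required` hits, then overwrite a random slot of the result for each later hit, driven by a cheap multiplicative congruential generator (the classic 65539 multiplier) so the inner loop stays numba-friendly. Stripped of the traversal, the idea reduces to the sketch below; names are illustrative, and note this is only an approximation of reservoir sampling, which would keep the n-th later item with probability count/n:

    import numpy as np

    def sample_stream(items, count, rng):
        # Keep at most `count` items from a stream of unknown length.
        kept = []
        for item in items:
            if len(kept) < count:
                kept.append(item)
            else:
                # Unconditional random replacement: cheaper than Algorithm R,
                # but biased towards items seen later in the stream.
                kept[rng.integers(0, count)] = item
        return kept

    rng = np.random.default_rng(42)
    print(sample_stream(range(1_000), 10, rng))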
From 070de48f0c8e20183cf77be3d2ff103a69c61303 Mon Sep 17 00:00:00 2001
From: Robin Message
Date: Wed, 13 Dec 2023 09:44:45 +0000
Subject: [PATCH 11/12] WIP: faster sampling of members of RumbaTree in find_pairs

---
 methods/matching/find_pairs.py |  8 ++--
 methods/utils/kd_tree.py       | 75 ++++++++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py
index e387c7c..7e679c0 100644
--- a/methods/matching/find_pairs.py
+++ b/methods/matching/find_pairs.py
@@ -214,14 +214,14 @@ def make_s_set_mask(
     k_miss = np.zeros(k_size, dtype=np.bool_)
 
     for k in range(k_size):
+        if (k % 100) == 0:
+            logging.info(f"{100 * k / k_size}% completed...")
         k_row = k_set_dist_thresholded[k]
-        possible_s = rumba_tree.members(k_row)
+        possible_s = rumba_tree.members_sample(k_row, required, rng)
         if len(possible_s) == 0:
             k_miss[k] = True
         else:
-            samples = min(len(possible_s), required)
-            chosen_s = rng.choice(possible_s, samples, replace=False)
-            s_include[chosen_s] = True
+            s_include[possible_s] = True
 
     return s_include, k_miss
 
diff --git a/methods/utils/kd_tree.py b/methods/utils/kd_tree.py
index 8d459dc..b44d3b9 100644
--- a/methods/utils/kd_tree.py
+++ b/methods/utils/kd_tree.py
@@ -214,6 +214,81 @@ def members(self, point: np.ndarray):
             if value >= low[d]:
                 queue.append(self.lefts[pos])
         return finds
+    def count_members(self, point: np.ndarray):
+        low = point - self.widths
+        high = point + self.widths
+        queue = [0]
+        count = 0
+        while len(queue) > 0:
+            pos = queue.pop()
+            d = self.ds[pos]
+            value = self.values[pos]
+            if math.isnan(value):
+                i = d
+                item = self.items[i]
+                while item != -1:
+                    # Check item
+                    found = True
+                    for d in range(self.dimensions):
+                        value = self.rows[item, d]
+                        if value < low[d]:
+                            found = False
+                            break
+                        if value > high[d]:
+                            found = False
+                            break
+                    if found:
+                        count += 1
+                    i += 1
+                    item = self.items[i]
+            else:
+                if value <= high[d]:
+                    queue.append(self.rights[pos])
+                if value >= low[d]:
+                    queue.append(self.lefts[pos])
+        return count
+    def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator):
+        low = point - self.widths
+        high = point + self.widths
+        queue = [0]
+        finds = []
+        rand = rng.integers(0, 2**32)
+        while len(queue) > 0:
+            pos = queue.pop()
+            d = self.ds[pos]
+            value = self.values[pos]
+            if math.isnan(value):
+                i = d
+                item = self.items[i]
+                while item != -1:
+                    # Check item
+                    found = True
+                    for d in range(self.dimensions):
+                        value = self.rows[item, d]
+                        if value < low[d]:
+                            found = False
+                            break
+                        if value > high[d]:
+                            found = False
+                            break
+                    if found:
+                        if len(finds) < count:
+                            finds.append(item)
+                        else:
+                            # Once full, replace a random item in finds so later matches
+                            # can still be sampled (cheap LCG; approximate reservoir sampling)
+                            pos = rand % len(finds)
+                            rand *= 65539
+                            rand &= 0x7FFFFFFF
+                            if pos < count:
+                                finds[pos] = item
+                    i += 1
+                    item = self.items[i]
+            else:
+                if value <= high[d]:
+                    queue.append(self.rights[pos])
+                if value >= low[d]:
+                    queue.append(self.lefts[pos])
+        return finds
 
 NAN = float('nan')
 def make_rumba_tree(tree: KDRangeTree, rows: np.ndarray):
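The final patch stops building one enormous tree over M: it shuffles M once, cuts it into roughly equal slices, builds a tree per slice, and has each query visit the slices in random order, stopping as soon as it has `required` candidates. The slice bookkeeping in isolation comes down to this sketch (toy sizes; the clamping constants 100 and 1e6 mirror the patch):

    import math
    import numpy as np

    m_size, k_size, required = 5_000_000, 1_000, 10
    rng = np.random.default_rng(0)

    # More slices only pay off when M is much larger than the roughly
    # k_size * required * 10 samples a whole run could ever consume.
    m_sets = max(1, min(100, math.floor(m_size // 1e6),
                        math.ceil(m_size / (k_size * required * 10))))

    m_lookup = np.arange(m_size)
    rng.shuffle(m_lookup)  # one global shuffle keeps every slice unbiased
    m_step = math.ceil(m_size / m_sets)

    def m_indexes(m_set):
        # Global M indexes covered by one sub-tree.
        return m_lookup[m_set * m_step:(m_set + 1) * m_step]

    print(m_sets, [len(m_indexes(i)) for i in range(m_sets)])  # 5 slices of 1,000,000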
shape: %a", {k_subset.shape}) @@ -204,20 +205,41 @@ def make_s_set_mask( required: int, rng: np.random.Generator ): - m_tree = make_kdrangetree(m_dist_thresholded, np.ones(m_dist_thresholded.shape[1])) - rumba_tree = make_rumba_tree(m_tree, m_dist_thresholded) - k_size = k_set_dist_thresholded.shape[0] m_size = m_dist_thresholded.shape[0] s_include = np.zeros(m_size, dtype=np.bool_) k_miss = np.zeros(k_size, dtype=np.bool_) + m_sets = max(1, min(100, math.floor(m_size // 1e6), math.ceil(m_size / (k_size * required * 10)))) + + m_lookup = np.arange(m_size) + rng.shuffle(m_lookup) + m_step = math.ceil(m_size / m_sets) + + def m_index(m_set: int, pos: int): + return m_lookup[m_set * m_step + pos] + def m_indexes(m_set: int): + return m_lookup[m_set * m_step:(m_set + 1) * m_step] + + m_trees = [make_kdrangetree(m_dist_thresholded[m_indexes(m_set)], np.ones(m_dist_thresholded.shape[1])) for m_set in range(m_sets)] + + rumba_trees = [make_rumba_tree(m_tree, m_dist_thresholded) for m_tree in m_trees] + for k in range(k_size): - if (k % 100) == 0: - logging.info(f"{100 * k / k_size}% completed...") k_row = k_set_dist_thresholded[k] - possible_s = rumba_tree.members_sample(k_row, required, rng) + m_order = np.arange(m_sets) + rng.shuffle(m_order) + possible_s = [] + for m_set in m_order: + next_possible_s = rumba_trees[m_set].members_sample(k_row, required, rng) + if possible_s is None: + possible_s = [m_index(m_set, s) for s in next_possible_s] + else: + take = min(required - len(possible_s), len(next_possible_s)) + possible_s[len(possible_s):len(possible_s)+take] = [m_index(m_set, s) for s in next_possible_s[0:take]] + if len(possible_s) == required: + break if len(possible_s) == 0: k_miss[k] = True else: @@ -277,6 +299,8 @@ def greedy_match( results.append((k_idx, min_dist_idx)) s_available[min_dist_idx] = False total_available -= 1 + else: + matchless.append(k_idx) else: matchless.append(k_idx)