# Source code for grispy.core

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# This file is part of the
#   GriSPy Project (https://github.com/mchalela/GriSPy).
# Copyright (c) 2019, Martin Chalela
# License: MIT
#   Full Text: https://github.com/mchalela/GriSPy/blob/master/LICENSE


# =============================================================================
# DOCS
# =============================================================================

"""GriSPy core class."""

# =============================================================================
# IMPORTS
# =============================================================================

import itertools

import attr
import numpy as np

from . import distances
from . import validators as vlds

# =============================================================================
# CONSTANTS
# =============================================================================

# Built-in distance functions selectable by name through the ``metric``
# argument of GriSPy. Each value is a callable from the ``distances`` module
# and is invoked as ``func(centre_0, centres, dim)``.
METRICS = {
    "euclid": distances.euclid,
    "haversine": distances.haversine,
    "vincenty": distances.vincenty,
}


# Shared empty integer array used as the "no result" value. Callers take
# ``EMPTY_ARRAY.copy()`` instead of using this object directly.
EMPTY_ARRAY = np.array([], dtype=int)


# =============================================================================
#  PERIODICITY CONF CLASS
# =============================================================================


@attr.s(frozen=True)
class PeriodicityConf:
    """Internal representation of the periodicity of the Grid.

    Immutable record built by ``GriSPy._build_periodicity``. When no axis is
    periodic, every field except ``periodic_flag`` is ``None``.
    """

    # True when at least one axis has a periodic range defined.
    periodic_flag = attr.ib()

    # Upper domain limit per axis, shape (1, dim); +inf on non-periodic axes.
    pd_hi = attr.ib()

    # Lower domain limit per axis, shape (1, dim); -inf on non-periodic axes.
    pd_low = attr.ib()

    # Unique non-zero offsets subtracted from a centre to create its
    # mirror images, one row per offset.
    periodic_edges = attr.ib()

    # Sign of every entry of ``periodic_edges``.
    periodic_direc = attr.ib()


# =============================================================================
# MAIN CLASS
# =============================================================================


@attr.s
class Grid:
    """Grid indexing.

    Grid is a regular grid indexing algorithm. This class indexes a set of
    k-dimensional points in a regular grid.

    Parameters
    ----------
    data: ndarray, shape(n,k)
        The n data points of dimension k to be indexed. This array is not
        copied, and so modifying this data may result in erroneous results.
        The data can be copied if the grid is built with copy_data=True.
    N_cells: positive int, optional
        The number of cells of each dimension to build the grid. The final
        grid will have N_cells**k number of cells. Default: 64
    copy_data: bool, optional
        Flag to indicate if the data should be copied in memory.
        Default: False

    Attributes
    ----------
    dim: int
        The dimension of a single data-point.
    grid: dict
        This dictionary contains the data indexed in a grid. The key is a
        tuple with the k-dimensional index of each grid cell. Empty cells
        do not have a key. The value is a tuple of data point indices which
        are located within the given cell.
    k_bins: ndarray, shape (N_cells + 1, k)
        The limits of the grid cells in each dimension.
    edges: ndarray, shape (2, k)
        Grid edges or bound values. The lower and upper bounds per dimension.
    epsilon: float
        Value of increment used to create the grid edges.
    ndata: int
        Total number of data-points.
    shape: tuple
        Number of cells per dimension.
    size: int
        Total number of cells.
    cell_width: ndarray
        Cell size in each dimension.

    """

    # User input params
    data = attr.ib(default=None, kw_only=False, repr=False)
    N_cells = attr.ib(default=64)
    copy_data = attr.ib(
        default=False, validator=attr.validators.instance_of(bool)
    )

    # =========================================================================
    # ATTRS INITIALIZATION
    # =========================================================================

    def __attrs_post_init__(self):
        """Init more params and build the grid."""
        if self.copy_data:
            self.data = self.data.copy()
        self.k_bins = self._make_bins()
        self.grid = self._build_grid()

    @data.validator
    def _validate_data(self, attribute, value):
        """Validate init params: data.

        Raises
        ------
        TypeError
            If ``value`` is not a numpy array.
        ValueError
            If the array is not 2-dimensional, is empty, or holds
            non-finite values.
        """
        # Check if numpy array
        if not isinstance(value, np.ndarray):
            raise TypeError(
                "Data: Argument must be a numpy array. "
                "Got instead type {}".format(type(value))
            )
        # Check if data has the expected dimension
        if value.ndim != 2:
            raise ValueError(
                "Data: Array has the wrong shape. Expected shape of (n, k), "
                "got instead {}".format(value.shape)
            )
        # Check if data has at least one point
        if value.size == 0:
            raise ValueError("Data: Array must have at least 1 point")
        # Check if every data point is valid
        if not np.isfinite(value).all():
            raise ValueError("Data: Array must have real numbers")

    @N_cells.validator
    def _validate_N_cells(self, attribute, value):
        """Validate init params: N_cells.

        Raises
        ------
        TypeError
            If ``value`` is not an ``int``.
        ValueError
            If ``value`` is lower than 1.
        """
        # Check if int
        if not isinstance(value, int):
            raise TypeError(
                "N_cells: Argument must be an integer. "
                "Got instead type {}".format(type(value))
            )
        # A single cell per dimension is the smallest valid grid
        if value < 1:
            raise ValueError(
                "N_cells: Argument must be at least 1. "
                "Got instead {}".format(value)
            )

    # =========================================================================
    # PROPERTIES
    # =========================================================================

    @property
    def dim(self):
        """Grid dimension."""
        return self.data.shape[1]

    @property
    def edges(self):
        """Edges of the grid in each dimension."""
        return self.k_bins[[0, -1], :].copy()

    @property
    def epsilon(self):
        """Epsilon used to expand the grid."""
        # Check the resolution of the input data and increase it
        # two orders of magnitude. This works for float{32,64}
        # Fix issue #7
        dtype = self.data.dtype
        if np.issubdtype(dtype, np.integer):
            return 1e-1
        # assume floating
        return np.finfo(dtype).resolution * 1e2

    @property
    def ndata(self):
        """Total number of data-points."""
        return len(self.data)

    @property
    def shape(self):
        """Grid shape, i.e. number of cells per dimension."""
        return (self.N_cells,) * self.dim

    @property
    def size(self):
        """Grid size, i.e. total number of cells."""
        return self.N_cells ** self.dim

    @property
    def cell_width(self):
        """Cell size in each dimension."""
        id0 = np.zeros((1, self.dim), dtype=int)
        lower, upper = self.cell_walls(id0)
        return upper - lower

    # =========================================================================
    # INTERNAL IMPLEMENTATION
    # =========================================================================

    def _make_bins(self):
        """Return bins values.

        The data range is widened by ``epsilon`` on both sides and split
        into ``N_cells`` equal intervals per dimension.
        """
        dmin = self.data.min(axis=0) - self.epsilon
        dmax = self.data.max(axis=0) + self.epsilon
        return np.linspace(dmin, dmax, self.N_cells + 1)

    def _digitize(self, data, bins):
        """Return data bin index.

        Parameters
        ----------
        data: ndarray
            Coordinates to digitize.
        bins: ndarray
            Equally spaced bin limits, either 1-d (single axis) or with one
            column per axis. Only the first two rows are used.

        Returns
        -------
        ndarray of int16
            Cell index of every coordinate. Coordinates far outside the
            grid saturate at the int16 limits.
        """
        if bins.ndim == 1:
            d = (data - bins[0]) / (bins[1] - bins[0])
        else:
            d = (data - bins[0, :]) / (bins[1, :] - bins[0, :])
        # allowed indices with int16: (-32768 to 32767). Clip before the
        # cast so that far away points saturate instead of wrapping around
        # to an unrelated (possibly valid) cell index.
        limits = np.iinfo(np.int16)
        return np.clip(d, limits.min, limits.max).astype(np.int16)

    def _build_grid(self):
        """Build the grid.

        Returns
        -------
        dict
            Maps the digits of every non-empty cell (as a tuple) to a tuple
            with the indices of the data points that fall in that cell.
        """
        # Digitize data points
        k_digit = self._digitize(self.data, self.k_bins)

        # Flatten the k-dimensional digits into a single cell id and sort
        # the points by it, so points of the same cell become contiguous
        compact_ind = np.ravel_multi_index(
            k_digit.T, self.shape, order="F", mode="clip"
        )
        compact_ind_sort = np.argsort(compact_ind)
        compact_ind = compact_ind[compact_ind_sort]
        k_digit = k_digit[compact_ind_sort]

        # Position in the sorted array where each cell id would start;
        # repeated positions are discarded to keep one entry per group
        split_ind = np.searchsorted(compact_ind, np.arange(self.size))
        deleted_cells = np.diff(np.append(-1, split_ind)).astype(bool)
        split_ind = split_ind[deleted_cells]

        # A trailing position past the last point does not start a group
        data_ind = np.arange(self.ndata)
        if split_ind[-1] > data_ind[-1]:
            split_ind = split_ind[:-1]

        list_ind = np.split(data_ind[compact_ind_sort], split_ind[1:])
        k_digit = k_digit[split_ind]

        return {
            tuple(digits): tuple(members)
            for digits, members in zip(k_digit, list_ind)
        }

    # =========================================================================
    # GRID API
    # =========================================================================

    def contains(self, points):
        """Check if points are contained within the grid.

        Parameters
        ----------
        points: ndarray, shape (m,k)
            The point or points to check against the grid domain.

        Returns
        -------
        bool: ndarray, shape (m,)
            Boolean array indicating if a point is contained within the grid.
            Points lying exactly on a grid edge are not contained.
        """
        # Validate inputs
        vlds.validate_centres(points, self.data)

        lower, upper = self.edges
        return ((lower < points) & (points < upper)).all(axis=1)

    def cell_digits(self, points):
        """Return grid cell indices for a given point.

        Parameters
        ----------
        points: ndarray, shape (m,k)
            The point or points to calculate the cell indices.

        Returns
        -------
        digits: ndarray, shape (m,k)
            Array of cell indices with same shape as `points`. If a point is
            outside of the grid edges `-1` is returned.
        """
        # `contains` validates the input
        outside = ~self.contains(points)

        digits = self._digitize(points, bins=self.k_bins)
        digits[outside] = -1
        return digits

    def cell_id(self, points):
        """Return grid cell unique id for a given point.

        Parameters
        ----------
        points: ndarray, shape (m,k)
            The point or points to calculate the cell unique id.

        Returns
        -------
        ids: ndarray, shape (m,)
            Array of cell unique ids for each point. If a point is outside of
            the grid edges `-1` is returned.
        """
        # `contains` validates the input
        outside = ~self.contains(points)

        digits = self._digitize(points, bins=self.k_bins)
        ids = np.ravel_multi_index(
            digits.T, self.shape, order="F", mode="clip"
        )
        ids[outside] = -1
        return ids

    def cell_digits2id(self, digits):
        """Return unique id of cells given their digits.

        Parameters
        ----------
        digits: ndarray, shape (m,k)
            Array of cell indices. Must be integers.

        Returns
        -------
        ids: ndarray, shape (m,)
            Array of cell unique ids for each point.
        """
        # Validate digits
        vlds.validate_digits(digits, self.N_cells)

        return np.ravel_multi_index(
            digits.T, self.shape, order="F", mode="clip"
        )

    def cell_id2digits(self, ids):
        """Return cell digits given their unique id.

        Parameters
        ----------
        ids: ndarray, shape (m,)
            Array of cell unique ids for each point.

        Returns
        -------
        digits: ndarray, shape (m,k)
            Array of cell indices.
        """
        # Validate ids
        vlds.validate_ids(ids, self.size)

        digits = np.unravel_index(ids, self.shape, order="F")
        digits = np.vstack(digits).T
        # Convert to int16 for consistency with _digitize
        return digits.astype(np.int16)

    def cell_walls(self, digits):
        """Return cell wall coordinates for given cell digits.

        Parameters
        ----------
        digits: ndarray, shape (m,k)
            Array of cell indices. Must be integers.

        Returns
        -------
        lower: ndarray, shape (m, k)
            Lower cell wall values for each point.
        upper: ndarray, shape (m, k)
            Upper cell wall values for each point.
        """
        # Validate digits
        vlds.validate_digits(digits, self.N_cells)

        kb = self.k_bins
        axes = range(self.dim)
        # get bin values for the walls
        lower = np.column_stack([kb[digits[:, k], k] for k in axes])
        upper = np.column_stack([kb[digits[:, k] + 1, k] for k in axes])
        return lower, upper

    def cell_centre(self, digits):
        """Return cell centre coordinates for given cell digits.

        Parameters
        ----------
        digits: ndarray, shape (m,k)
            Array of cell indices. Must be integers.

        Returns
        -------
        centres: ndarray, shape (m, k)
            Cell centre for each point.
        """
        # `cell_walls` validates the digits
        lower, upper = self.cell_walls(digits)
        return (lower + upper) * 0.5

    def cell_count(self, digits):
        """Return number of points within given cell digits.

        Parameters
        ----------
        digits: ndarray, shape (m,k)
            Array of cell indices. Must be integers.

        Returns
        -------
        count: ndarray, shape (m,)
            Number of points stored in each cell. Empty cells count 0.
        """
        # Validate digits
        vlds.validate_digits(digits, self.N_cells)

        get = self.grid.get
        counts = [len(get(tuple(dgt), ())) for dgt in digits]
        return np.asarray(counts)

    def cell_points(self, digits):
        """Return indices of points within given cell digits.

        Parameters
        ----------
        digits: ndarray, shape (m,k)
            Array of cell indices. Must be integers.

        Returns
        -------
        points: list, length m
            List of m integer arrays. Each array has the indices of the data
            points stored in that cell; empty cells give an empty array.
        """
        # Validate digits
        vlds.validate_digits(digits, self.N_cells)

        get = self.grid.get
        # Force an integer dtype: an empty cell would otherwise produce a
        # float64 array, which cannot be used to index the data.
        return [np.asarray(get(tuple(dgt), ()), dtype=int) for dgt in digits]
@attr.s
class GriSPy(Grid):
    """Grid Search in Python.

    GriSPy is a regular grid search algorithm for quick nearest-neighbor
    lookup.

    This class indexes a set of k-dimensional points in a regular grid
    providing a fast approach for nearest neighbors queries. Optional periodic
    boundary conditions can be provided for each axis individually.

    The algorithm has the following queries implemented:
    - bubble_neighbors: find neighbors within a given radius. A different
    radius for each centre can be provided.
    - shell_neighbors: find neighbors within given lower and upper radius.
    Different lower and upper radius can be provided for each centre.
    - nearest_neighbors: find the nth nearest neighbors for each centre.

    Other methods:
    - set_periodicity: set periodicity condition after the grid was built.

    To be implemented:
    - box_neighbors: find neighbors within a k-dimensional squared box of
    a given size and orientation.
    - n_jobs: number of cores for parallel computation.

    Parameters
    ----------
    data: ndarray, shape(n,k)
        The n data points of dimension k to be indexed. This array is not
        copied, and so modifying this data may result in erroneous results.
        The data can be copied if the grid is built with copy_data=True.
    N_cells: positive int, optional
        The number of cells of each dimension to build the grid. The final
        grid will have N_cells**k number of cells. Default: 64
    copy_data: bool, optional
        Flag to indicate if the data should be copied in memory.
        Default: False
    periodic: dict, optional
        Dictionary indicating if the data domain is periodic in some or all its
        dimensions. The key is an integer that corresponds to the number of
        dimensions in data, going from 0 to k-1. The value is a tuple with the
        domain limits and the data must be contained within these limits. If an
        axis is not specified, or if its value is None, it will be considered
        as non-periodic. Important: The periodicity only works within one
        periodic range. Default: all axis set to None.
        Example, periodic = { 0: (0, 360), 1: None}.
    metric: str, optional
        Metric definition to compute distances. Options: 'euclid', 'haversine'
        'vincenty' or a custom callable.

    Attributes
    ----------
    dim: int
        The dimension of a single data-point.
    grid: dict
        This dictionary contains the data indexed in a grid. The key is a
        tuple with the k-dimensional index of each grid cell. Empty cells
        do not have a key. The value is a tuple of data point indices which
        are located within the given cell.
    k_bins: ndarray, shape (N_cells + 1, k)
        The limits of the grid cells in each dimension.
    edges: ndarray, shape (2, k)
        Grid edges or bound values. The lower and upper bounds per dimension.
    epsilon: float
        Value of increment used to create the grid edges.
    ndata: int
        Total number of data-points.
    shape: tuple
        Number of cells per dimension.
    size: int
        Total number of cells.
    cell_width: ndarray
        Cell size in each dimension.
    periodic_flag: bool
        If any dimension has periodicity.
    periodic_conf: grispy.core.PeriodicityConf
        Statistics and intermediate results to make easy and fast the
        searches with periodicity.

    """

    # User input params
    periodic = attr.ib(factory=dict)
    metric = attr.ib(default="euclid")

    # =========================================================================
    # ATTRS INITIALIZATION
    # =========================================================================

    def __attrs_post_init__(self):
        """Init more params and build the grid."""
        super().__attrs_post_init__()

        self.periodic, self.periodic_conf = self._build_periodicity(
            periodic=self.periodic, dim=self.dim
        )

    @metric.validator
    def _validate_metric(self, attribute, value):
        """Validate init params: metric.

        Raises
        ------
        ValueError
            If ``value`` is neither a callable nor one of the METRICS names.
        """
        # Test callables first: the membership test below would fail with an
        # unrelated "unhashable type" error for unhashable values.
        if callable(value):
            return
        if not isinstance(value, str) or value not in METRICS:
            metric_names = ", ".join(METRICS)
            raise ValueError(
                "Metric: Got an invalid name: '{}'. "
                "Options are: {} or a callable".format(value, metric_names)
            )

    @periodic.validator
    def _validate_periodic(self, attribute, value):
        """Validate init params: periodic.

        Raises
        ------
        TypeError
            If ``value`` is not a dict, a key is not an int, or a value is
            neither None nor a tuple of exactly 2 real numbers.
        ValueError
            If the first edge is not lower than the second one.
        """
        # Check if dict
        if not isinstance(value, dict):
            raise TypeError(
                "Periodicity: Argument must be a dictionary. "
                "Got instead type {}".format(type(value))
            )

        # An empty dict means no periodicity; the loop simply does not run.
        for k, v in value.items():
            # Check if integer
            if not isinstance(k, int):
                raise TypeError(
                    "Periodicity: Keys must be integers. "
                    "Got instead type {}".format(type(k))
                )

            if v is None:
                continue

            # Check if tuple
            if not isinstance(v, tuple):
                raise TypeError(
                    "Periodicity: Values must be tuples. "
                    "Got instead type {}".format(type(v))
                )

            # Check if edges are exactly two valid numbers. The length is
            # tested first so short tuples do not raise IndexError and long
            # ones do not reach the periodicity builder.
            has_valid_number = len(v) == 2 and all(
                isinstance(edge, (int, float)) for edge in v
            )
            if not has_valid_number:
                raise TypeError(
                    "Periodicity: Argument must be a tuple of "
                    "2 real numbers as edge descriptors. "
                )

            # Check that first number is lower than second
            if not v[0] < v[1]:
                raise ValueError(
                    "Periodicity: First argument in tuple must be "
                    "lower than second argument."
                )

    # =========================================================================
    # PROPERTIES
    # =========================================================================

    @property
    def periodic_flag(self):
        """Proxy to ``periodic_conf.periodic_flag``."""
        return self.periodic_conf.periodic_flag

    # =========================================================================
    # INTERNAL IMPLEMENTATION
    # =========================================================================

    def _build_periodicity(self, periodic, dim):
        """Cleanup the periodicity configuration.

        Remove the unnecessary axis from the periodic dict and also creates
        a configuration for use in the search.

        Parameters
        ----------
        periodic: dict
            User supplied periodicity, already validated.
        dim: int
            Number of dimensions of the data.

        Returns
        -------
        cleaned_periodic: dict
            One entry per axis (0 to dim-1) when any axis is periodic,
            otherwise an empty dict.
        conf: PeriodicityConf
            Derived values used by the mirror searches.
        """
        # assume no periodicity
        cleaned_periodic = {}
        pd_hi, pd_low = None, None
        periodic_edges, periodic_direc = None, None

        periodic_flag = any(x is not None for x in periodic.values())

        # now check if periodic
        if periodic_flag:
            pd_hi = np.ones((1, dim)) * np.inf
            pd_low = np.ones((1, dim)) * -np.inf
            periodic_edges = []
            for k in range(dim):
                aux = periodic.get(k)
                cleaned_periodic[k] = aux
                if aux:
                    pd_low[0, k] = aux[0]
                    pd_hi[0, k] = aux[1]
                    # (low, 0, high): the three candidate offsets of an axis
                    aux = np.insert(aux, 1, 0.0)
                else:
                    aux = np.zeros((1, 3))
                # Tile so that, after the reshape below, the rows enumerate
                # all 3**dim combinations of per-axis candidates
                periodic_edges = np.hstack(
                    [
                        periodic_edges,
                        np.tile(aux, (3 ** (dim - 1 - k), 3 ** k)).T.ravel(),
                    ]
                )

            periodic_edges = periodic_edges.reshape(dim, 3 ** dim).T
            periodic_edges -= periodic_edges[::-1]
            periodic_edges = np.unique(periodic_edges, axis=0)

            # Drop the all-zero offset, it is the original centre itself
            mask = periodic_edges.any(axis=1)
            periodic_edges = periodic_edges[mask]
            periodic_direc = np.sign(periodic_edges)

        return cleaned_periodic, PeriodicityConf(
            periodic_flag=periodic_flag,
            pd_hi=pd_hi,
            pd_low=pd_low,
            periodic_edges=periodic_edges,
            periodic_direc=periodic_direc,
        )

    def _distance(self, centre_0, centres):
        """Compute distance between points.

        Uses ``self.metric``: either a callable or one of the names in
        METRICS ('euclid', 'haversine', 'vincenty'). The metric is called as
        ``metric(centre_0, centres, dim)``. An empty ``centres`` gives an
        empty array without calling the metric.
        """
        if len(centres) == 0:
            return EMPTY_ARRAY.copy()
        metric_func = (
            self.metric if callable(self.metric) else METRICS[self.metric]
        )
        return metric_func(centre_0, centres, self.dim)

    def _get_neighbor_distance(self, centres, neighbor_cells):
        """Retrieve neighbor distances within the given cells.

        Parameters
        ----------
        centres: ndarray, shape (m,k)
            Query points.
        neighbor_cells: list, length m
            For each centre, the digits of the cells to inspect.

        Returns
        -------
        n_dis: list, length m
            Distances from each centre to every point in its cells.
        n_idxs: list, length m
            Data indices of those points.
        """
        n_idxs, n_dis = [], []
        for centre, neighbors in zip(centres, neighbor_cells):
            if len(neighbors) == 0:
                # no neighbor cells
                n_idxs.append(EMPTY_ARRAY.copy())
                n_dis.append(EMPTY_ARRAY.copy())
                continue

            # Points stored in each neighbor cell (empty cells have no key)
            ind_tmp = [self.grid.get(nt, ()) for nt in map(tuple, neighbors)]

            # Merge all of them in a single array of indices
            inds = np.fromiter(
                itertools.chain.from_iterable(ind_tmp), dtype=np.int32
            )

            n_idxs.append(inds)
            n_dis.append(self._distance(centre, self.data[inds]))

        return n_dis, n_idxs

    # Neighbor-cells methods
    def _get_neighbor_cells(
        self,
        centres,
        distance_upper_bound,
        distance_lower_bound=0,
        shell_flag=False,
    ):
        """Retrieve cells touched by the search radius.

        Parameters
        ----------
        centres: ndarray, shape (m,k)
            Query points.
        distance_upper_bound: ndarray, shape (m,)
            Search radius of each centre.
        distance_lower_bound: ndarray, shape (m,), optional
            Inner radius of each centre. Only read when ``shell_flag`` is
            True.
        shell_flag: bool, optional
            If True, cells entirely inside the inner radius are discarded.

        Returns
        -------
        list, length m
            For each centre, an array with the digits of the candidate
            cells (one row per cell), or an empty array if there is none.
        """
        reach = np.asarray(distance_upper_bound)[:, np.newaxis]
        grid_low = self.k_bins[0, :]
        grid_high = self.k_bins[-1, :]

        # Search regions completely beyond the grid on any axis
        out_of_field = (
            (centres - reach > grid_high) | (centres + reach < grid_low)
        ).any(axis=1)
        if np.all(out_of_field):
            # no neighbor cells
            return [EMPTY_ARRAY.copy() for _ in centres]

        # Build the box of cells to explore. The fractional index is clamped
        # to the grid *before* the integer cast, so arbitrarily distant or
        # infinite bounds cannot overflow the integer type.
        cell_size = self.k_bins[1, :] - grid_low
        last_cell = self.N_cells - 1
        k_cell_min = np.clip(
            (centres - reach - grid_low) / cell_size, 0, last_cell
        ).astype(int)
        k_cell_max = np.clip(
            (centres + reach - grid_low) / cell_size, 0, last_cell
        ).astype(int)

        # Half of the cell diagonal
        cell_radii = 0.5 * np.sum(cell_size ** 2) ** 0.5
        axes = np.arange(self.dim)

        neighbor_cells = []
        for i, centre in enumerate(centres):
            # For each centre i, build an array of cell digits, shape (:,k)
            k_grids = np.meshgrid(
                *[
                    np.arange(k_cell_min[i, k], k_cell_max[i, k] + 1)
                    for k in axes
                ]
            )
            cells = np.array([g.flatten() for g in k_grids]).T

            # Compute the distance from centre i to the middle of its
            # candidate cells and discard the cells that the search radius
            # (enlarged by the half diagonal) does not reach
            cells_physical = self.k_bins[cells, axes] + 0.5 * cell_size
            cell_distance = self._distance(centre, cells_physical)

            mask_cells = cell_distance < distance_upper_bound[i] + cell_radii
            if shell_flag:
                mask_cells &= (
                    cell_distance > distance_lower_bound[i] - cell_radii
                )

            if np.any(mask_cells):
                neighbor_cells.append(cells[mask_cells])
            else:
                neighbor_cells.append(EMPTY_ARRAY.copy())
        return neighbor_cells

    def _near_boundary(self, centres, distance_upper_bound):
        """Flag centres closer than their radius to a periodic edge."""
        mask = np.zeros((len(centres), self.dim), dtype=bool)
        for k in range(self.dim):
            limits = self.periodic[k]
            if limits is None:
                continue
            coord = centres[:, k]
            mask[:, k] = (abs(coord - limits[0]) < distance_upper_bound) | (
                abs(coord - limits[1]) < distance_upper_bound
            )
        return mask.any(axis=1)

    def _mirror(self, centre, distance_upper_bound):
        """Return the mirror images of a centre that reach the domain.

        A mirror image is kept when, moved by the search radius in the
        direction of its offset, it lies within the periodic limits on
        every axis.
        """
        conf = self.periodic_conf

        mirror_centre = centre - conf.periodic_edges
        reach = conf.periodic_direc * distance_upper_bound + mirror_centre
        inside = (reach >= conf.pd_low) & (reach <= conf.pd_hi)
        return mirror_centre[inside.all(axis=1)]

    def _mirror_universe(self, centres, distance_upper_bound):
        """Generate Terran centres in the Mirror Universe.

        Returns
        -------
        terran_centres: ndarray, shape (t,k)
            Mirror images of the centres that are near a periodic edge.
        terran_indices: ndarray, shape (t,)
            Index in ``centres`` of the original of each mirror image.
        """
        terran_centres = np.array([[]] * self.dim).T
        terran_indices = np.array([], dtype=int)

        near_boundary = self._near_boundary(centres, distance_upper_bound)
        if not np.any(near_boundary):
            return terran_centres, terran_indices

        # Collect the pieces and concatenate once at the end
        mirrored, owners = [terran_centres], [terran_indices]
        for i in np.flatnonzero(near_boundary):
            mirror_centre = self._mirror(centres[i], distance_upper_bound[i])
            if len(mirror_centre) > 0:
                mirrored.append(mirror_centre)
                owners.append(np.repeat(i, len(mirror_centre)))

        return np.concatenate(mirrored, axis=0), np.concatenate(owners)

    def _add_mirror_neighbors(
        self,
        centres,
        distance_upper_bound,
        neighbors_distances,
        neighbors_indices,
    ):
        """Append, in place, the neighbors found through periodic images."""
        # terran_centres are the centres in the mirror universe for those
        # near the boundary.
        terran_centres, terran_indices = self._mirror_universe(
            centres, distance_upper_bound
        )
        terran_neighbor_cells = self._get_neighbor_cells(
            terran_centres, distance_upper_bound[terran_indices]
        )
        terran_distances, terran_neighbors = self._get_neighbor_distance(
            terran_centres, terran_neighbor_cells
        )

        # i runs over normal indices that have a terran counterpart
        # t runs over terran indices, 0 to len(terran_centres)
        for t, i in enumerate(terran_indices):
            neighbors_distances[i] = np.concatenate(
                (neighbors_distances[i], terran_distances[t])
            )
            neighbors_indices[i] = np.concatenate(
                (neighbors_indices[i], terran_neighbors[t])
            )

    # =========================================================================
    # PERIODICITY
    # =========================================================================

    def set_periodicity(self, periodic, inplace=False):
        """Set periodicity conditions.

        This allows to define or change the periodicity limits. With
        ``inplace=True`` the existing grid is reused; otherwise a new
        instance is built from the same data.

        Important: The periodicity only works within one periodic range.

        Parameters
        ----------
        periodic: dict
            Dictionary indicating if the data domain is periodic in some or all
            its dimensions. The key is an integer that corresponds to the
            number of dimensions in data, going from 0 to k-1. The value is a
            tuple with the domain limits and the data must be contained within
            these limits. If an axis is not specified, or if its value is None,
            it will be considered as non-periodic.
            Example, periodic = { 0: (0, 360), 1: None}.
        inplace: boolean, optional (default=False)
            If it is True, set the periodicity on the current GriSPy instance
            and return None. Otherwise a new instance is created and
            returned.
        """
        if not inplace:
            return GriSPy(
                data=self.data,
                N_cells=self.N_cells,
                metric=self.metric,
                copy_data=self.copy_data,
                periodic=periodic,
            )

        # Run the same validation the constructor applies
        periodic_attr = attr.fields(GriSPy).periodic
        periodic_attr.validator(self, periodic_attr, periodic)
        self.periodic, self.periodic_conf = self._build_periodicity(
            periodic=periodic, dim=self.dim
        )

    # =========================================================================
    # SEARCH API
    # =========================================================================

    def bubble_neighbors(
        self,
        centres,
        distance_upper_bound=-1.0,
        sorted=False,
        kind="quicksort",
    ):
        """Find all points within given distances of each centre.

        Parameters
        ----------
        centres: ndarray, shape (m,k)
            The point or points to search for neighbors of.
        distance_upper_bound: scalar or ndarray of length m
            The radius of points to return. If a scalar is provided, the same
            distance will apply for every centre. An ndarray with individual
            distances can also be provided.
        sorted: bool, optional
            If True the returned neighbors will be ordered by increasing
            distance to the centre. Default: False.
        kind: str, optional
            When sorted = True, the sorting algorithm can be specified in this
            keyword. Available algorithms are: ['quicksort', 'mergesort',
            'heapsort', 'stable']. Default: 'quicksort'

        Returns
        -------
        distances: list, length m
            Returns a list of m arrays. Each array has the distances to the
            neighbors of that centre.
        indices: list, length m
            Returns a list of m arrays. Each array has the indices to the
            neighbors of that centre.
        """
        # Validate inputs
        vlds.validate_centres(centres, self.data)
        vlds.validate_distance_bound(distance_upper_bound, self.periodic)
        vlds.validate_bool(sorted)
        vlds.validate_sortkind(kind)

        # Match distance_upper_bound shape with centres shape
        if np.isscalar(distance_upper_bound):
            distance_upper_bound = np.full(
                len(centres), distance_upper_bound, dtype=float
            )
        else:
            vlds.validate_equalsize(centres, distance_upper_bound)
            distance_upper_bound = np.asarray(distance_upper_bound)

        # Get neighbors
        neighbor_cells = self._get_neighbor_cells(
            centres, distance_upper_bound
        )
        neighbors_distances, neighbors_indices = self._get_neighbor_distance(
            centres, neighbor_cells
        )

        # We need to generate mirror centres for periodic boundaries...
        if self.periodic_flag:
            self._add_mirror_neighbors(
                centres,
                distance_upper_bound,
                neighbors_distances,
                neighbors_indices,
            )

        # The cells only give candidates: keep the points within the radius
        for i, radius in enumerate(distance_upper_bound):
            keep = neighbors_distances[i] <= radius
            neighbors_distances[i] = neighbors_distances[i][keep]
            neighbors_indices[i] = neighbors_indices[i][keep]
            if sorted:
                order = np.argsort(neighbors_distances[i], kind=kind)
                neighbors_distances[i] = neighbors_distances[i][order]
                neighbors_indices[i] = neighbors_indices[i][order]

        return neighbors_distances, neighbors_indices

    def shell_neighbors(
        self,
        centres,
        distance_lower_bound=-1.0,
        distance_upper_bound=-1.0,
        sorted=False,
        kind="quicksort",
    ):
        """Find all points within given lower and upper distances of each centre.

        The distance condition is:
            `distance_lower_bound <= distance < distance_upper_bound`

        Parameters
        ----------
        centres: ndarray, shape (m,k)
            The point or points to search for neighbors of.
        distance_lower_bound: scalar or ndarray of length m
            The minimum distance of points to return. If a scalar is provided,
            the same distance will apply for every centre. An ndarray with
            individual distances can also be provided.
        distance_upper_bound: scalar or ndarray of length m
            The maximum distance of points to return. If a scalar is provided,
            the same distance will apply for every centre. An ndarray with
            individual distances can also be provided.
        sorted: bool, optional
            If True the returned neighbors will be ordered by increasing
            distance to the centre. Default: False.
        kind: str, optional
            When sorted = True, the sorting algorithm can be specified in this
            keyword. Available algorithms are: ['quicksort', 'mergesort',
            'heapsort', 'stable']. Default: 'quicksort'

        Returns
        -------
        distances: list, length m
            Returns a list of m arrays. Each array has the distances to the
            neighbors of that centre.
        indices: list, length m
            Returns a list of m arrays. Each array has the indices to the
            neighbors of that centre.
        """
        # Validate inputs
        vlds.validate_centres(centres, self.data)
        vlds.validate_bool(sorted)
        vlds.validate_sortkind(kind)
        vlds.validate_shell_distances(
            distance_lower_bound, distance_upper_bound, self.periodic
        )

        # Match distance bounds shapes with centres shape
        if np.isscalar(distance_lower_bound):
            distance_lower_bound = np.full(
                len(centres), distance_lower_bound, dtype=float
            )
        else:
            vlds.validate_equalsize(centres, distance_lower_bound)
            distance_lower_bound = np.asarray(distance_lower_bound)

        if np.isscalar(distance_upper_bound):
            distance_upper_bound = np.full(
                len(centres), distance_upper_bound, dtype=float
            )
        else:
            vlds.validate_equalsize(centres, distance_upper_bound)
            distance_upper_bound = np.asarray(distance_upper_bound)

        # Get neighbors
        neighbor_cells = self._get_neighbor_cells(
            centres,
            distance_upper_bound=distance_upper_bound,
            distance_lower_bound=distance_lower_bound,
            shell_flag=True,
        )
        neighbors_distances, neighbors_indices = self._get_neighbor_distance(
            centres, neighbor_cells
        )

        # We need to generate mirror centres for periodic boundaries...
        if self.periodic_flag:
            self._add_mirror_neighbors(
                centres,
                distance_upper_bound,
                neighbors_distances,
                neighbors_indices,
            )

        # The cells only give candidates: keep the points inside the shell
        for i in range(len(centres)):
            dist = neighbors_distances[i]
            keep = (dist < distance_upper_bound[i]) & (
                dist >= distance_lower_bound[i]
            )
            neighbors_distances[i] = dist[keep]
            neighbors_indices[i] = neighbors_indices[i][keep]
            if sorted:
                order = np.argsort(neighbors_distances[i], kind=kind)
                neighbors_distances[i] = neighbors_distances[i][order]
                neighbors_indices[i] = neighbors_indices[i][order]

        return neighbors_distances, neighbors_indices

    def nearest_neighbors(self, centres, n=1, kind="quicksort"):
        """Find the n nearest-neighbors for each centre.

        Parameters
        ----------
        centres: ndarray, shape (m,k)
            The point or points to search for neighbors of.
        n: int, optional
            The number of neighbors to fetch for each centre. Default: 1.
        kind: str, optional
            The returned neighbors will be ordered by increasing distance
            to the centre. The sorting algorithm can be specified in this
            keyword. Available algorithms are: ['quicksort', 'mergesort',
            'heapsort', 'stable']. Default: 'quicksort'

        Returns
        -------
        distances: list, length m
            Returns a list of m arrays. Each array has the distances to the
            neighbors of that centre.
        indices: list, length m
            Returns a list of m arrays. Each array has the indices to the
            neighbors of that centre.
        """
        # Validate input
        vlds.validate_centres(centres, self.data)
        vlds.validate_n_nearest(n, self.data, self.periodic)
        vlds.validate_sortkind(kind)

        # Initial definitions
        N_centres = len(centres)
        centres_lookup_ind = np.arange(0, N_centres)
        n_found = np.zeros(N_centres, dtype=bool)

        # First estimation is the cell radii (half of the cell diagonal)
        cell_size = self.k_bins[1, :] - self.k_bins[0, :]
        cell_radii = 0.5 * np.sum(cell_size ** 2) ** 0.5
        lower_distance_tmp = np.zeros(N_centres)
        upper_distance_tmp = cell_radii * np.ones(N_centres)
        shell_step = cell_size.min()

        neighbors_indices = [EMPTY_ARRAY.copy() for _ in range(N_centres)]
        neighbors_distances = [EMPTY_ARRAY.copy() for _ in range(N_centres)]

        # Search in consecutive, non-overlapping shells of growing radius
        # until every centre has collected n neighbors
        while not np.all(n_found):
            pending = ~n_found
            ndis_tmp, nidx_tmp = self.shell_neighbors(
                centres[pending],
                distance_lower_bound=lower_distance_tmp[pending],
                distance_upper_bound=upper_distance_tmp[pending],
            )

            for i_tmp, i in enumerate(centres_lookup_ind[pending]):
                n_known = len(neighbors_indices[i])
                n_new = len(nidx_tmp[i_tmp])
                if n <= n_new + n_known:
                    # Enough candidates: take only the ones still missing
                    n_more = n - n_known
                    n_found[i] = True
                else:
                    # Keep them all and move on to the next shell
                    n_more = n_new
                    lower_distance_tmp[i] = upper_distance_tmp[i]
                    upper_distance_tmp[i] += shell_step

                sorted_ind = np.argsort(ndis_tmp[i_tmp], kind=kind)[:n_more]
                neighbors_distances[i] = np.hstack(
                    (neighbors_distances[i], ndis_tmp[i_tmp][sorted_ind])
                )
                neighbors_indices[i] = np.hstack(
                    (neighbors_indices[i], nidx_tmp[i_tmp][sorted_ind])
                )

        return neighbors_distances, neighbors_indices