Tutorial

Example with a 2D Uniform Distribution

This example generates a set of points drawn from a 2D uniform random distribution. We will use these points to show how you can use Grid and GriSPy to index data and query for neighbors.

Keep in mind that the class Grid provides the functionality needed to index the points. The class GriSPy inherits all of that functionality from Grid and adds the methods needed to perform neighbor queries: set_periodicity, bubble_neighbors, shell_neighbors and nearest_neighbors. For completeness, we will show the example using GriSPy.

### Table of Contents
* Create a random distribution of points
* Index the points with Grid/GriSPy
* Creating your custom distance function

Create a random distribution of points

Import GriSPy and other packages

[1]:
import numpy as np
import matplotlib.pyplot as plt

import grispy as gsp

Create random points and centres

[2]:
Npoints = 10 ** 3
Ncentres = 2
dim = 2
Lbox = 100.0

rng = np.random.default_rng(seed=0)
data = rng.uniform(0, Lbox, size=(Npoints, dim))
centres = rng.uniform(0, Lbox, size=(Ncentres, dim))

Index the points with Grid/GriSPy

[3]:
grid = gsp.GriSPy(data, N_cells=32)

You can get some information about the grid using some helpful attributes and methods:

[4]:
grid.shape   # number of cells per dimension
[4]:
(32, 32)
[5]:
grid.edges   # grid edges in each dimension
[5]:
array([[2.97613690e-01, 1.90001607e-02],
       [9.99501352e+01, 9.98067457e+01]])
[6]:
digits = grid.cell_digits(centres)   # check the cell indices (or digits) where a given set of points would fall
digits
[6]:
array([[31,  1],
       [29,  8]], dtype=int16)

This means that, of the 2 input points, the first would fall in cell (31, 1) and the second in cell (29, 8).
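To build intuition for what a cell digit is, here is a minimal NumPy sketch of how a point could be mapped to cell indices on a grid with equally sized cells. The function name `cell_digits_sketch` and its arguments are hypothetical, and the uniform cell width is an assumption (it is consistent with the cell centres printed in this tutorial); this is not GriSPy's actual implementation.

```python
import numpy as np


def cell_digits_sketch(points, lower, upper, n_cells):
    """Map points to integer cell indices on a uniform grid (illustrative only)."""
    points = np.asarray(points, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)

    # Fractional position of each coordinate between the lower and upper edge
    frac = (points - lower) / (upper - lower)

    # Scale to cell units and truncate to obtain the cell index
    digits = np.floor(frac * n_cells).astype(int)

    # Clamp so that points on the upper edge fall in the last cell
    # (points outside the grid are clamped to the border cells as well)
    return np.clip(digits, 0, n_cells - 1)


example_digits = cell_digits_sketch(
    [[30.0, 80.0], [100.0, 0.0]],
    lower=[0.0, 0.0],
    upper=[100.0, 100.0],
    n_cells=4,
)
print(example_digits)
```

With 4 cells of width 25 per axis, the point (30, 80) lands in cell (1, 3), and the point (100, 0) is clamped into cell (3, 0).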

[7]:
grid.cell_count(digits)   # number of points in the array `data` that were indexed in each of these cells at initialization
[7]:
array([2, 0])
[8]:
grid.cell_centre(digits)    # position of each cell centre
[8]:
array([[98.39306458,  4.69655073],
       [92.16478198, 26.52512006]])
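As a sanity check, the cell centres printed above can be reproduced from the values of grid.edges, assuming every cell has the same width along each axis. The helper `cell_centre_sketch` is hypothetical and only illustrates the arithmetic; the edge values are copied from the output of cell [5].

```python
import numpy as np

# Grid edges as printed by `grid.edges`: row 0 is the lower edge, row 1 the upper edge
edges = np.array([[2.97613690e-01, 1.90001607e-02],
                  [9.99501352e+01, 9.98067457e+01]])
n_cells = 32


def cell_centre_sketch(digits, edges, n_cells):
    """Centre of each cell, assuming equally sized cells (illustrative only)."""
    lower, upper = edges[0], edges[1]
    width = (upper - lower) / n_cells          # cell width along each axis
    return lower + (np.asarray(digits) + 0.5) * width


cell_centres = cell_centre_sketch([[31, 1], [29, 8]], edges, n_cells)
print(cell_centres)
```

Up to the rounding of the printed edges, the result agrees with the output of grid.cell_centre(digits).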

The digits of cells can be used in other useful methods like cell_walls or cell_points.

Another important method is contains, which reports whether each point in a new set lies inside the grid.

[9]:
points = np.array([
    [30., 50.],
    [-10., -10.]
])
grid.contains(points)
[9]:
array([ True, False])
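Conceptually, this check can be thought of as a bounding-box test against the grid edges. The sketch below is an assumption about the behaviour, not GriSPy's code: the function `contains_sketch` is hypothetical, and treating both edges as inclusive is a guess.

```python
import numpy as np

# Grid edges as printed by `grid.edges`: row 0 is the lower edge, row 1 the upper edge
edges = np.array([[2.97613690e-01, 1.90001607e-02],
                  [9.99501352e+01, 9.98067457e+01]])


def contains_sketch(points, edges):
    """True for every point that lies inside the box spanned by the edges."""
    points = np.asarray(points, dtype=float)
    lower, upper = edges[0], edges[1]
    # A point is inside only if every coordinate is within its axis range
    inside = (points >= lower) & (points <= upper)
    return inside.all(axis=1)


points = np.array([[30.0, 50.0], [-10.0, -10.0]])
inside_mask = contains_sketch(points, edges)
print(inside_mask)
```

For the two points used above, the sketch gives the same answer as grid.contains.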

Set periodicity conditions

Set periodic conditions on the x-axis (axis=0) and the y-axis (axis=1):

[10]:
periodic = {0: (0, Lbox), 1: (0, Lbox)}
grid.set_periodicity(periodic, inplace=True)
grid
[10]:
GriSPy(N_cells=32, copy_data=False, periodic={0: (0, 100.0), 1: (0, 100.0)}, metric='euclid')

You can also build a periodic grid in a single step:

[11]:
grid = gsp.GriSPy(data, periodic=periodic)

Important: Periodic boundaries don’t have to match the values of grid.edges exactly. The edges of the grid are computed from the actual data to optimize queries on sparse data. So there is nothing wrong with the edge of the grid being at 99.9 while the periodic boundary is set at exactly 100.
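To illustrate what periodic boundaries mean for distances, here is a small sketch of the minimum-image convention for a box that spans 0 to `box_length` on every axis. This is a common way to compute periodic distances and is shown only for intuition; the function `periodic_distance_sketch` is hypothetical, and GriSPy may handle periodicity differently internally.

```python
import numpy as np


def periodic_distance_sketch(c0, points, box_length):
    """Euclidean distance under periodic boundaries (minimum-image convention).

    Assumes all coordinates lie in [0, box_length).
    """
    c0 = np.asarray(c0, dtype=float)
    points = np.asarray(points, dtype=float)

    # Direct separation along each axis
    delta = np.abs(points - c0)

    # Along each axis, keep the shorter of the direct and the wrapped separation
    delta = np.minimum(delta, box_length - delta)

    return np.sqrt((delta ** 2).sum(axis=1))


wrapped = periodic_distance_sketch(
    [1.0, 50.0], [[99.0, 50.0], [51.0, 50.0]], box_length=100.0
)
print(wrapped)
```

A point at x = 99 is only 2 units away from a centre at x = 1 once the box wraps around, while a point at x = 51 is 50 units away either way.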

Query for neighbors within upper_radii

[12]:
upper_radii = 10.0
bubble_dist, bubble_ind = grid.bubble_neighbors(
    centres, distance_upper_bound=upper_radii
)
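A brute-force version of a bubble query is a useful reference when checking results on small data sets. The sketch below uses only NumPy and ignores periodicity. The function `bubble_brute_force` is hypothetical, the strict `<` comparison is an assumption, and the return layout (one array of distances and one array of indices per centre) is inferred from how the plotting cell in this tutorial iterates over bubble_ind.

```python
import numpy as np


def bubble_brute_force(data, centres, radius):
    """Indices and distances of all points closer than `radius` to each centre."""
    data = np.asarray(data, dtype=float)
    dists, inds = [], []
    for centre in np.asarray(centres, dtype=float):
        # Euclidean distance from this centre to every data point
        d = np.sqrt(((data - centre) ** 2).sum(axis=1))
        mask = d < radius
        inds.append(np.flatnonzero(mask))
        dists.append(d[mask])
    return dists, inds


# Toy data, independent of the tutorial variables
toy_rng = np.random.default_rng(seed=1)
toy_data = toy_rng.uniform(0, 100.0, size=(200, 2))
toy_centres = np.array([[50.0, 50.0]])

toy_dist, toy_ind = bubble_brute_force(toy_data, toy_centres, radius=10.0)
print(len(toy_ind[0]), "points inside the bubble")
```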

Query for neighbors in a shell between lower_radii and upper_radii

[13]:
upper_radii = 20.0
lower_radii = 10.0
shell_dist, shell_ind = grid.shell_neighbors(
    centres,
    distance_lower_bound=lower_radii,
    distance_upper_bound=upper_radii
)
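The shell query can be pictured as keeping only the points whose distance to a centre falls between the two bounds. Here is a minimal NumPy sketch that builds the full distance matrix with broadcasting. The function `shell_brute_force` is hypothetical, periodicity is ignored, and which bound is inclusive is an assumption.

```python
import numpy as np


def shell_brute_force(data, centres, lower, upper):
    """Points with lower <= distance < upper for each centre (illustrative only)."""
    data = np.asarray(data, dtype=float)
    centres = np.asarray(centres, dtype=float)

    # Pairwise distance matrix of shape (n_centres, n_points) via broadcasting
    diff = centres[:, None, :] - data[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))

    in_shell = (dist >= lower) & (dist < upper)
    inds = [np.flatnonzero(row) for row in in_shell]
    dists = [dist[i, idx] for i, idx in enumerate(inds)]
    return dists, inds


toy_points = np.array([[0.0, 0.0], [5.0, 0.0], [12.0, 0.0], [15.0, 0.0], [25.0, 0.0]])
toy_shell_dist, toy_shell_ind = shell_brute_force(
    toy_points, [[0.0, 0.0]], lower=10.0, upper=20.0
)
print(toy_shell_ind[0], toy_shell_dist[0])
```

Only the points at distances 12 and 15 fall inside the shell between 10 and 20.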

Query for the n nearest neighbors

[14]:
n_nearest = 10
near_dist, near_ind = grid.nearest_neighbors(centres, n=n_nearest)
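For comparison, the n nearest neighbors can also be found by sorting all distances, which is simple but scales poorly with the number of points. The function `nearest_brute_force` below is a hypothetical NumPy-only sketch that ignores periodicity; returning the neighbors sorted from nearest to farthest is a choice made for this sketch.

```python
import numpy as np


def nearest_brute_force(data, centres, n):
    """Distances and indices of the n closest points to each centre."""
    data = np.asarray(data, dtype=float)
    centres = np.asarray(centres, dtype=float)

    # Pairwise distance matrix of shape (n_centres, n_points)
    diff = centres[:, None, :] - data[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))

    # Sort each row and keep the first n columns (nearest first)
    order = np.argsort(dist, axis=1)[:, :n]
    nearest_dist = np.take_along_axis(dist, order, axis=1)
    return nearest_dist, order


toy_points = np.array([[0.0, 0.0], [3.0, 0.0], [1.0, 0.0], [10.0, 0.0]])
toy_near_dist, toy_near_ind = nearest_brute_force(toy_points, [[0.4, 0.0]], n=2)
print(toy_near_ind, toy_near_dist)
```

For a centre at x = 0.4, the two closest points are those at x = 0 and x = 1, at distances 0.4 and 0.6.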

Plot results

[15]:
fig, axes = plt.subplots(1, 3, figsize=(14, 5))

ax = axes[0]
ax.set_title("Bubble query")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=3)
for ind in bubble_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C3", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)

ax = axes[1]
ax.set_title("Shell query")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=2)
for ind in shell_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C2", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)

ax = axes[2]
ax.set_title("n-Nearest query")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=2)
for ind in near_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C0", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)

fig.tight_layout()
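[Figure: three scatter panels titled "Bubble query", "Shell query" and "n-Nearest query", with the neighbors found for each centre highlighted]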

Creating your custom distance function

Let’s assume that we intend to compare our distances using the Levenshtein metric for similarity between texts (https://en.wikipedia.org/wiki/Levenshtein_distance).
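For intuition, the Levenshtein distance counts the minimum number of insertions, deletions and substitutions needed to turn one sequence into another. The following pure-Python sketch uses the classic dynamic-programming formulation; the function `levenshtein_sketch` is hypothetical and is meant only to show what is being computed, not as a replacement for an optimized library.

```python
def levenshtein_sketch(a, b):
    """Edit distance between two sequences (strings, tuples, lists)."""
    # Row of distances between the prefix of `a` processed so far
    # and every prefix of `b`; initially the prefix of `a` is empty.
    previous = list(range(len(b) + 1))

    for i, item_a in enumerate(a, start=1):
        current = [i]  # turning i items of `a` into an empty sequence costs i deletions
        for j, item_b in enumerate(b, start=1):
            substitution = previous[j - 1] + (item_a != item_b)
            insertion = current[j - 1] + 1
            deletion = previous[j] + 1
            current.append(min(substitution, insertion, deletion))
        previous = current

    return previous[-1]


print(levenshtein_sketch("kitten", "sitting"))
print(levenshtein_sketch((1.0, 2.0), (1.0, 3.0)))
```

Because the algorithm only compares items for equality, it also works on tuples of numbers: two 2D points that differ in a single coordinate are at distance 1.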

Luckily, the excellent textdistance library provides an efficient implementation of this distance.

We can install it with

$ pip install textdistance

and then import it with

[16]:
import textdistance

To make this custom distance compatible with GriSPy, we must define a function that receives 3 parameters:

* c0: the center to which we seek the distance.
* centres: the \(C\) centers whose distance to c0 we want to calculate.
* dim: the dimension of each center and of c0.

Finally, the function must return a np.ndarray with \(C\) elements, where the \(j\)-th element is the distance between c0 and centres\(_j\).

[17]:
def levenshtein(c0, centres, dim):
    # textdistance only operates on lists and tuples
    c0 = tuple(c0)

    # create an empty array with one slot
    # per distance to compute
    distances = np.empty(len(centres))
    for idx, c1 in enumerate(centres):

        # textdistance only operates on lists and tuples
        c1 = tuple(c1)

        # calculate the distance
        dis = textdistance.levenshtein(c0, c1)

        # store the distance
        distances[idx] = dis

    return distances
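The same three-parameter signature can be used for other metrics. As a further illustration, here is a vectorized Manhattan (taxicab) distance written against the convention described above; the function `manhattan` is a hypothetical example and is not part of GriSPy.

```python
import numpy as np


def manhattan(c0, centres, dim):
    """Manhattan distance from c0 to each of the C centres.

    Follows the (c0, centres, dim) convention and returns an array of C elements.
    """
    c0 = np.asarray(c0, dtype=float)
    centres = np.asarray(centres, dtype=float)
    # Sum of absolute coordinate differences over the first `dim` axes
    return np.abs(centres[:, :dim] - c0[:dim]).sum(axis=1)


toy_distances = manhattan(np.array([0.0, 0.0]), np.array([[1.0, 1.0], [2.0, 3.0]]), 2)
print(toy_distances)
```

Since it works on whole arrays instead of looping over the centres, a function written this way is usually much faster than an element-by-element loop.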

Then we create the grid with the custom levenshtein distance and run a bubble query:

[18]:
grid = gsp.GriSPy(data, metric=levenshtein)

upper_radii = 10.0
lev_dist, lev_ind = grid.bubble_neighbors(
    centres, distance_upper_bound=upper_radii)

Finally, we can check our bubble_neighbors result with a plot:

[19]:
fig, axes = plt.subplots(figsize=(6, 6))

ax = axes
ax.set_title("Bubble query with Levenshtein distance")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=3)
for ind in lev_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C3", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)

fig.tight_layout()
[Figure: scatter plot titled "Bubble query with Levenshtein distance"]