Tutorial

Example in 2D Uniform Distribution

This example generates a 2D random uniform distribution and then uses GriSPy to search for neighbors within a given radius and/or for the n nearest neighbors.


Import GriSPy and other packages

[1]:
import numpy as np
import matplotlib.pyplot as plt

from grispy import GriSPy
[2]:
%matplotlib inline

Create random points and centres

[3]:
Npoints = 10 ** 3
Ncentres = 2
dim = 2
Lbox = 100.0

np.random.seed(0)
data = np.random.uniform(0, Lbox, size=(Npoints, dim))
centres = np.random.uniform(0, Lbox, size=(Ncentres, dim))

Build the grid with the data

[4]:
gsp = GriSPy(data)
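
To give an intuition for what a regular grid index does, the following is a conceptual sketch that bins points into square cells using NumPy only. It is not GriSPy's internal code; the names `n_cells`, `cell_idx` and `counts` are hypothetical, and the value 20 is chosen only because the object representation printed later in this tutorial reports `N_cells=20`.

```python
import numpy as np

# Conceptual sketch only: this is NOT GriSPy's implementation.
np.random.seed(0)
Lbox = 100.0
n_cells = 20  # hypothetical number of cells per axis
points = np.random.uniform(0, Lbox, size=(1000, 2))

# Integer cell coordinates of every point along each axis
cell_width = Lbox / n_cells
cell_idx = np.floor(points / cell_width).astype(int)
# Points lying exactly on the upper edge are kept in the last cell
cell_idx = np.clip(cell_idx, 0, n_cells - 1)

# Flatten the 2D cell coordinates to a single id and count points per cell
flat_id = np.ravel_multi_index(tuple(cell_idx.T), (n_cells, n_cells))
counts = np.bincount(flat_id, minlength=n_cells ** 2)

print(counts.reshape(n_cells, n_cells))
```

With such an index, a neighbor query only needs to inspect the cells that overlap the search region instead of every point.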

Set periodicity. Periodic conditions are applied on the x-axis (axis=0) and the y-axis (axis=1)

[5]:
periodic = {0: (0, Lbox), 1: (0, Lbox)}
gsp.set_periodicity(periodic)
[5]:
GriSPy(N_cells=20, periodic={0: (0, 100.0), 1: (0, 100.0)}, metric='euclid', copy_data=False)

You can also build a periodic grid in a single step

[6]:
gsp = GriSPy(data, periodic=periodic)
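
To illustrate what periodic conditions mean for distances, here is a minimal sketch of the minimum-image convention in a square box. The helper `periodic_distance` is hypothetical and is not part of GriSPy; it assumes every coordinate lies in the interval from 0 to `box_length` and that the same box length applies to all axes.

```python
import numpy as np


def periodic_distance(c0, points, box_length):
    """Euclidean distance from c0 to each point in a periodic box.

    Hypothetical helper, assuming all coordinates lie in [0, box_length].
    """
    delta = np.abs(points - c0)
    # Along each axis, take the shorter of the direct and the wrapped path
    delta = np.minimum(delta, box_length - delta)
    return np.sqrt((delta ** 2).sum(axis=1))


Lbox = 100.0
c0 = np.array([1.0, 50.0])
points = np.array([[99.0, 50.0], [5.0, 50.0]])

plain = np.sqrt(((points - c0) ** 2).sum(axis=1))
wrapped = periodic_distance(c0, points, Lbox)

print(plain)    # distances ignoring periodicity
print(wrapped)  # distances across the periodic boundary
```

The first point is 98 units away in a plain box but only 2 units away once the x-axis wraps around, which is why a periodic grid can return neighbors located on the opposite side of the box.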

Query for neighbors within upper_radii

[7]:
upper_radii = 10.0
bubble_dist, bubble_ind = gsp.bubble_neighbors(
    centres, distance_upper_bound=upper_radii
)
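
As a point of comparison, a bubble query can be reproduced by brute force with NumPy alone. The sketch below regenerates the same random data as this tutorial and ignores periodicity, so its result may differ from the GriSPy output for centres close to the box edges. Whether points lying exactly on the radius are included is an assumption here (they are excluded).

```python
import numpy as np

# Same random data as in the tutorial
np.random.seed(0)
Lbox = 100.0
data = np.random.uniform(0, Lbox, size=(1000, 2))
centres = np.random.uniform(0, Lbox, size=(2, 2))
radius = 10.0

# Pairwise Euclidean distances, shape (n_centres, n_points)
dist = np.linalg.norm(centres[:, None, :] - data[None, :, :], axis=2)

# One array of indices and one array of distances per centre
brute_ind = [np.flatnonzero(row < radius) for row in dist]
brute_dist = [row[idx] for row, idx in zip(dist, brute_ind)]

for i, idx in enumerate(brute_ind):
    print(f"centre {i}: {len(idx)} points within {radius}")
```

This costs one distance evaluation per centre and point, which is the work a grid-based search is designed to avoid.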

Query for neighbors in a shell within lower_radii and upper_radii

[8]:
upper_radii = 10.0
lower_radii = 8.0
shell_dist, shell_ind = gsp.shell_neighbors(
    centres,
    distance_lower_bound=lower_radii,
    distance_upper_bound=upper_radii
)
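
A shell can be thought of as an outer bubble with an inner bubble removed. The sketch below checks that idea with NumPy on a single, hypothetical centre, without periodicity. The treatment of points lying exactly on either bound (lower bound included, upper bound excluded) is an assumption and may not match GriSPy.

```python
import numpy as np

np.random.seed(0)
data = np.random.uniform(0, 100.0, size=(1000, 2))
centre = np.array([50.0, 50.0])  # hypothetical centre for illustration
lower, upper = 8.0, 10.0

dist = np.linalg.norm(data - centre, axis=1)

# Direct selection of the shell
shell_idx = np.flatnonzero((dist >= lower) & (dist < upper))

# Same selection built as "outer bubble minus inner bubble"
outer = np.flatnonzero(dist < upper)
inner = np.flatnonzero(dist < lower)
shell_from_bubbles = np.setdiff1d(outer, inner)

print(len(shell_idx), "points in the shell")
```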

Query for the n nearest neighbors

[9]:
n_nearest = 10
near_dist, near_ind = gsp.nearest_neighbors(centres, n=n_nearest)
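
For reference, the n nearest neighbors of each centre can also be obtained by brute force, sorting all distances and keeping the first n. This sketch uses the same random data as the tutorial and ignores periodicity, so it is a comparison baseline under that assumption rather than a description of how GriSPy performs the search.

```python
import numpy as np

np.random.seed(0)
Lbox = 100.0
data = np.random.uniform(0, Lbox, size=(1000, 2))
centres = np.random.uniform(0, Lbox, size=(2, 2))
n_nearest = 10

# Pairwise Euclidean distances, shape (n_centres, n_points)
dist = np.linalg.norm(centres[:, None, :] - data[None, :, :], axis=2)

# Indices of the n closest points per centre, ordered by increasing distance
order = np.argsort(dist, axis=1)[:, :n_nearest]
nearest_dist = np.take_along_axis(dist, order, axis=1)

print(order.shape)         # one row of indices per centre
print(nearest_dist[:, 0])  # distance to the closest point of each centre
```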

Plot results

[10]:
fig, axes = plt.subplots(1, 3, figsize=(14, 5))

ax = axes[0]
ax.set_title("Bubble query")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=3)
for ind in bubble_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C3", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)


ax = axes[1]
ax.set_title("Shell query")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=2)
for ind in shell_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C2", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)

ax = axes[2]
ax.set_title("n-Nearest query")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=2)
for ind in near_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C0", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)

fig.tight_layout()

[Figure: _images/tutorial_19_0.png, three panels titled "Bubble query", "Shell query" and "n-Nearest query"]

Creating your custom distance function

Suppose we want to compute our distances with Levenshtein’s metric for similarity between texts (https://en.wikipedia.org/wiki/Levenshtein_distance).
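
To make the metric concrete, here is a minimal pure-Python sketch of the Levenshtein (edit) distance using the standard dynamic-programming recurrence. The function name `edit_distance` is hypothetical and the code is only illustrative; it works on any pair of sequences, including strings and tuples of numbers.

```python
def edit_distance(a, b):
    """Levenshtein distance between two sequences (illustrative sketch)."""
    # previous[j] holds the distance between the first i-1 items of a
    # and the first j items of b
    previous = list(range(len(b) + 1))
    for i, item_a in enumerate(a, start=1):
        current = [i]
        for j, item_b in enumerate(b, start=1):
            substitution = previous[j - 1] + (item_a != item_b)
            insertion = current[j - 1] + 1
            deletion = previous[j] + 1
            current.append(min(substitution, insertion, deletion))
        previous = current
    return previous[-1]


print(edit_distance("kitten", "sitting"))       # 3
print(edit_distance((1.0, 2.0), (1.0, 3.0)))    # 1
```

Applied to two coordinate tuples, the result counts how many coordinates must be substituted, inserted or deleted to turn one tuple into the other.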

Luckily, the excellent textdistance library implements this distance efficiently.

We can install it with

$ pip install textdistance

and then import it with

[11]:
import textdistance

To make this custom distance compatible with GriSPy, we must define a function that receives 3 parameters:

- c0: the center from which the distances are measured.
- centres: the \(C\) centers whose distance to c0 we want to calculate.
- dim: the dimension of each center and of c0.

Finally, the function must return a np.ndarray with \(C\) elements, where the \(j\)-th element is the distance between c0 and centres\(_j\).

[12]:
def levenshtein(c0, centres, dim):
    # textdistance only operates on lists and tuples
    c0 = tuple(c0)

    # create an empty array with the required
    # number of distances
    distances = np.empty(len(centres))
    for idx, c1 in enumerate(centres):

        # textdistance only operates on lists and tuples
        c1 = tuple(c1)

        # calculate the distance
        dis = textdistance.levenshtein(c0, c1)

        # store the distance
        distances[idx] = dis

    return distances
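
As another illustration of the three-parameter signature described above, the following sketch defines a vectorized Manhattan (L1) metric that needs no loop. The function `manhattan` is a hypothetical example, not something shipped with GriSPy, and it has been checked here only as a standalone function.

```python
import numpy as np


def manhattan(c0, centres, dim):
    """Manhattan distance from c0 to each row of centres (illustrative)."""
    c0 = np.asarray(c0, dtype=float)
    # Ensure a 2D array with one centre per row
    centres = np.asarray(centres, dtype=float).reshape(-1, dim)
    # Sum of absolute coordinate differences, one value per centre
    return np.abs(centres - c0).sum(axis=1)


c0 = np.array([0.0, 0.0])
centres = np.array([[1.0, 2.0], [-3.0, 4.0], [0.0, 0.0]])
distances = manhattan(c0, centres, dim=2)

print(distances)  # one distance per centre
```

A function of this shape would presumably be passed through the same metric argument that the tutorial uses for levenshtein.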

Then we create the grid with the custom distance, and run the code

[13]:
gsp = GriSPy(data, metric=levenshtein)

upper_radii = 10.0
lev_dist, lev_ind = gsp.bubble_neighbors(
    centres, distance_upper_bound=upper_radii)

Finally we can check our bubble_neighbors result with a plot

[14]:
fig, axes = plt.subplots(figsize=(6, 6))

ax = axes
ax.set_title("Bubble query with Levenshtein distance")
ax.scatter(data[:, 0], data[:, 1], c="k", marker=".", s=3)
for ind in lev_ind:
    ax.scatter(data[ind, 0], data[ind, 1], c="C3", marker="o", s=5)
ax.plot(centres[:, 0], centres[:, 1], "ro", ms=10)

fig.tight_layout()

[Figure: _images/tutorial_28_0.png, "Bubble query with Levenshtein distance"]