Skip to content

Commit

Permalink
Implement delaunay edge finding
Browse files Browse the repository at this point in the history
  • Loading branch information
ksunden committed Sep 17, 2019
1 parent 097bf9b commit 23164af
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 15 deletions.
78 changes: 64 additions & 14 deletions attune/workup/_holistic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Function for processing multi-dependent tuning data."""

import itertools

import matplotlib.pyplot as plt
import numpy as np
import scipy

import WrightTools as wt

Expand All @@ -11,10 +14,12 @@
def holistic(
data,
channel,
dependents,
curve,
*,
level=False,
cutoff_factor=0.1,
gtol=0.01,
ltol=0.1,
autosave=True,
save_directory=None,
**spline_kwargs,
Expand All @@ -23,38 +28,83 @@ def holistic(
# TODO: docstring
# HACKS
data = data.copy()
opa_index = 1

axis="wa"
# TODO: check if level does what we want
if level:
data.level(channel, 0, -3)

# take channel moments
data.moment(axis="wa", channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=0)
data.moment(axis="wa", channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=1)
data.moment(axis=axis, channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=0)
data.moment(axis=axis, channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=1)
amplitudes = data.channels[-2]
centers = data.channels[-1]
data.transform(*[a for a in data.axis_expressions if a != axis])

# would be nice if this was retained in the file I'm using, but it can be hacked
C1 = curve.dependents['0']
data.create_variable("w1_Crystal_1", values=C1[:, None, None])

data.transform("w1_Crystal_1", "w1_Delay_1", "wa")

# this suprises me... Kyle can you resolve?
assert data.w1_Delay_1.shape == (25, 51, 1)
# TODO: gtol/ltol should maybe be moved to wt
cutoff = amplitudes.max() * gtol
print(cutoff)
amplitudes.clip(min=cutoff)
centers[np.isnan(amplitudes)] = np.nan
#max_axis = tuple(i for i, v in enumerate(data.axes[0].shape) if v > 1)
#cutoff = np.amax(amplitudes[:], axis=1, keepdims=True) * ltol
#amplitudes.clip(min=cutoff)



# preapre for plot
# prepare for plot
fig, gs = wt.artists.create_figure(width='single', cols=[1, 'cbar'])
cmap = wt.artists.colormaps['default']
cmap.set_bad([0.75] * 3, 1.)
cmap.set_under([0.75] * 3, 1.)

ax = plt.subplot(gs[0, 0])
cax = plt.subplot(gs[1])
X, Y, Z = wt.artists.pcolor_helper(data.axes[0].points, data.axes[1].points, amplitudes.points)
ax.pcolor(X, Y, Z)
wt.artists.plot_colorbar(cax, cmap="viridis")#, vlim=(np.nanmin(Z), np.nanmax(Z)))


# would be nice if this was retained in the file I'm using, but it can be hacked
#C1 = curve.dependents['0']
#data.create_variable("w1_Crystal_1", values=C1[:, None, None])
#data.transform("w1_Crystal_1", "w1_Delay_1", "wa")
# TODO, make sure array axis isn't counted in "full" (may need separate helper)
points = list(zip(*[a.full.flatten() for a in data.axes]))
ndim = len(data.axes)
delaunay = scipy.spatial.Delaunay(points)

amp_interp = scipy.interpolate.LinearNDInterpolator(delaunay, amplitudes.full.flatten())
cen_interp = scipy.interpolate.LinearNDInterpolator(delaunay, centers.full.flatten())

out_points = []
for p in curve.setpoints[:]:
iso_points = []
for s, pts, vals in find_simplices_containing(delaunay, cen_interp, p):
iso_points.extend(edge_intersections(pts, vals, p))
iso_points = np.array(iso_points)
if iso_points.size > 5:
ax.scatter(*iso_points.T, s=2, label=f"{p} iso points")
#out_points.append((moment(iso_points.T[i], amp_interp(iso_points)) for i in range(ndim)))

#ax.scatter(out_points, label="out_points")
#ax.legend()

plt.show()


def find_simplices_containing(delaunay, interpolator, point):
for s in delaunay.simplices:
extrema = interpolator([p for p in delaunay.points[s]])
if min(extrema) < point <= max(extrema):
yield s, delaunay.points[s], extrema

def edge_intersections(points, evaluated, target):
sortord = np.argsort(evaluated)
evaluated = evaluated[sortord]
points = points[sortord]
for (p1, p2), (v1,v2) in zip(itertools.combinations(points,2), itertools.combinations(evaluated, 2)):
if v1 < target <= v2:
yield tuple(p1[i] + (p2[i]-p1[i])*((target - v1)/(v2-v1)) for i in range(len(p1)))


def old():
Expand Down
6 changes: 5 additions & 1 deletion tests/workup/holistic/2018-11-30/bvwnu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import matplotlib
matplotlib.use("TkAgg")

import tempfile
import attune
import WrightTools as wt
Expand All @@ -17,5 +20,6 @@
]
old = attune.TopasCurve.read(curve_paths, interaction_string='NON-NON-NON-Sig')
# do calculation
new = attune.workup.holistic(d, "array_signal", old)
d.transform("w1_Crystal_1", "w1_Delay_1", "wa")
new = attune.workup.holistic(d, "array_signal", [], old, gtol=.000, level=True)
# check

0 comments on commit 23164af

Please sign in to comment.