diff --git a/attune/workup/_holistic.py b/attune/workup/_holistic.py index 4f9674e..d67047a 100644 --- a/attune/workup/_holistic.py +++ b/attune/workup/_holistic.py @@ -1,7 +1,10 @@ """Function for processing multi-dependent tuning data.""" +import itertools import matplotlib.pyplot as plt +import numpy as np +import scipy import WrightTools as wt @@ -11,10 +14,12 @@ def holistic( data, channel, + dependents, curve, *, level=False, - cutoff_factor=0.1, + gtol=0.01, + ltol=0.1, autosave=True, save_directory=None, **spline_kwargs, @@ -23,38 +28,83 @@ def holistic( # TODO: docstring # HACKS data = data.copy() - opa_index = 1 + + axis="wa" + # TODO: check if level does what we want + if level: + data.level(channel, 0, -3) + # take channel moments - data.moment(axis="wa", channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=0) - data.moment(axis="wa", channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=1) + data.moment(axis=axis, channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=0) + data.moment(axis=axis, channel=channel, resultant=wt.kit.joint_shape(*data.axes[:-1]), moment=1) amplitudes = data.channels[-2] centers = data.channels[-1] + data.transform(*[a for a in data.axis_expressions if a != axis]) - # would be nice if this was retained in the file I'm using, but it can be hacked - C1 = curve.dependents['0'] - data.create_variable("w1_Crystal_1", values=C1[:, None, None]) - - data.transform("w1_Crystal_1", "w1_Delay_1", "wa") - # this suprises me... Kyle can you resolve? - assert data.w1_Delay_1.shape == (25, 51, 1) + # TODO: gtol/ltol should maybe be moved to wt + cutoff = amplitudes.max() * gtol + print(cutoff) + amplitudes.clip(min=cutoff) + centers[np.isnan(amplitudes)] = np.nan + #max_axis = tuple(i for i, v in enumerate(data.axes[0].shape) if v > 1) + #cutoff = np.amax(amplitudes[:], axis=1, keepdims=True) * ltol + #amplitudes.clip(min=cutoff) - - - # preapre for plot + # prepare for plot fig, gs = wt.artists.create_figure(width='single', cols=[1, 'cbar']) cmap = wt.artists.colormaps['default'] cmap.set_bad([0.75] * 3, 1.) cmap.set_under([0.75] * 3, 1.) ax = plt.subplot(gs[0, 0]) + cax = plt.subplot(gs[1]) X, Y, Z = wt.artists.pcolor_helper(data.axes[0].points, data.axes[1].points, amplitudes.points) ax.pcolor(X, Y, Z) + wt.artists.plot_colorbar(cax, cmap="viridis")#, vlim=(np.nanmin(Z), np.nanmax(Z))) + + + # would be nice if this was retained in the file I'm using, but it can be hacked + #C1 = curve.dependents['0'] + #data.create_variable("w1_Crystal_1", values=C1[:, None, None]) + #data.transform("w1_Crystal_1", "w1_Delay_1", "wa") + # TODO, make sure array axis isn't counted in "full" (may need separate helper) + points = list(zip(*[a.full.flatten() for a in data.axes])) + ndim = len(data.axes) + delaunay = scipy.spatial.Delaunay(points) + + amp_interp = scipy.interpolate.LinearNDInterpolator(delaunay, amplitudes.full.flatten()) + cen_interp = scipy.interpolate.LinearNDInterpolator(delaunay, centers.full.flatten()) + + out_points = [] + for p in curve.setpoints[:]: + iso_points = [] + for s, pts, vals in find_simplices_containing(delaunay, cen_interp, p): + iso_points.extend(edge_intersections(pts, vals, p)) + iso_points = np.array(iso_points) + if iso_points.size > 5: + ax.scatter(*iso_points.T, s=2, label=f"{p} iso points") + #out_points.append((moment(iso_points.T[i], amp_interp(iso_points)) for i in range(ndim))) + + #ax.scatter(out_points, label="out_points") + #ax.legend() plt.show() +def find_simplices_containing(delaunay, interpolator, point): + for s in delaunay.simplices: + extrema = interpolator([p for p in delaunay.points[s]]) + if min(extrema) < point <= max(extrema): + yield s, delaunay.points[s], extrema +def edge_intersections(points, evaluated, target): + sortord = np.argsort(evaluated) + evaluated = evaluated[sortord] + points = points[sortord] + for (p1, p2), (v1,v2) in zip(itertools.combinations(points,2), itertools.combinations(evaluated, 2)): + if v1 < target <= v2: + yield tuple(p1[i] + (p2[i]-p1[i])*((target - v1)/(v2-v1)) for i in range(len(p1))) def old(): diff --git a/tests/workup/holistic/2018-11-30/bvwnu.py b/tests/workup/holistic/2018-11-30/bvwnu.py index 6bb2cb2..38dc187 100644 --- a/tests/workup/holistic/2018-11-30/bvwnu.py +++ b/tests/workup/holistic/2018-11-30/bvwnu.py @@ -1,3 +1,6 @@ +import matplotlib +matplotlib.use("TkAgg") + import tempfile import attune import WrightTools as wt @@ -17,5 +20,6 @@ ] old = attune.TopasCurve.read(curve_paths, interaction_string='NON-NON-NON-Sig') # do calculation -new = attune.workup.holistic(d, "array_signal", old) +d.transform("w1_Crystal_1", "w1_Delay_1", "wa") +new = attune.workup.holistic(d, "array_signal", [], old, gtol=.000, level=True) # check