Skip to content

Commit

Permalink
Adding Holo reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
beniroquai committed Mar 13, 2024
1 parent 8680c1b commit 4cb6998
Show file tree
Hide file tree
Showing 8 changed files with 424 additions and 147 deletions.
164 changes: 118 additions & 46 deletions imswitch/imcontrol/controller/controllers/HoloController.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from imswitch.imcontrol.view import guitools
from imswitch.imcommon.model import initLogger
from ..basecontrollers import LiveUpdatedController


import threading
import time
class HoloController(LiveUpdatedController):
""" Linked to HoloWidget."""

Expand All @@ -33,16 +33,17 @@ def __init__(self, *args, **kwargs):
self.mWavelength = 488*1e-9
self.NA=.3
self.k0 = 2*np.pi/(self.mWavelength)
self.lastProcessedTime = time.time()

if not isNIP:
return
self.PSFpara = nip.PSF_PARAMS()
self.PSFpara.wavelength = self.mWavelength
self.PSFpara.NA=self.NA
self.PSFpara.pixelsize = self.pixelsize

self.availableReconstructionModes = ["off", "inline", "offaxis"]
self.reconstructionMode = self.availableReconstructionModes[0]
self.dz = 40*1e-3

# Prepare image computation worker
self.imageComputationWorker = self.HoloImageComputationWorker()
self.imageComputationWorker.set_pixelsize(self.pixelsize)
Expand All @@ -52,21 +53,37 @@ def __init__(self, *args, **kwargs):

self.imageComputationThread = Thread()
self.imageComputationWorker.moveToThread(self.imageComputationThread)
self.sigImageReceived.connect(self.imageComputationWorker.computeHoloImage)
#self.sigImageReceived.connect(self.imageComputationWorker.computeHoloImage)
self.imageComputationThread.start()

# Connect CommunicationChannel signals
self._commChannel.sigUpdateImage.connect(self.update)

# Connect HoloWidget signals
self._widget.sigShowToggled.connect(self.setShowHolo)
# Connect HoloWidget signals
self._widget.sigShowInLineToggled.connect(self.setShowInLineHolo)
self._widget.sigShowOffAxisToggled.connect(self.setShowOffAxisHolo)
self._widget.sigUpdateRateChanged.connect(self.changeRate)
self._widget.sigSliderValueChanged.connect(self.valueChanged)

self._widget.sigInLineSliderValueChanged.connect(self.inLineValueChanged)
self._widget.sigOffAxisSliderValueChanged.connect(self.offAxisValueChanged)
self._widget.btnSelectCCCenter.clicked.connect(self.selectCCCenter)
self.changeRate(self._widget.getUpdateRate())
self.setShowHolo(self._widget.getShowHoloChecked())

def valueChanged(self, magnitude):
self.setShowInLineHolo(self._widget.getShowInLineHoloChecked())
self.setShowOffAxisHolo(self._widget.getShowOffAxisHoloChecked())

def selectCCCenter(self):
# get the center of the cross correlation
self.CCCenter = self._widget.getCCCenterFromNapari()
self.CCRadius = self._widget.getCCRadius()
if self.CCRadius is None or self.CCRadius<50:
self.CCRadius = 100
self.imageComputationWorker.set_CCCenter(self.CCCenter)
self.imageComputationWorker.set_CCRadius(self.CCRadius)

def offAxisValueChanged(self, magnitude):
self.dz = magnitude*1e-3
self.imageComputationWorker.set_dz(self.dz)

def inLineValueChanged(self, magnitude):
""" Change magnitude. """
self.dz = magnitude*1e-3
self.imageComputationWorker.set_dz(self.dz)
Expand All @@ -77,81 +94,136 @@ def __del__(self):
if hasattr(super(), '__del__'):
super().__del__()

def setShowHolo(self, enabled):
def setShowInLineHolo(self, enabled):
""" Show or hide Holo. """
self.pixelsize = self._widget.getPixelSize()
self.mWavelength = self._widget.getWvl()
self.NA = self._widget.getNA()
self.k0 = 2 * np.pi / (self.mWavelength)
self.active = enabled
self.init = False

def update(self, detectorName, im, init, isCurrentDetector):
self.reconstructionMode = self.availableReconstructionModes[1]
self.imageComputationWorker.setReconstructionMode(self.reconstructionMode)
self.imageComputationWorker.setActive(enabled)

def setShowOffAxisHolo(self, enabled):
""" Show or hide Holo. """
self.pixelsize = self._widget.getPixelSize()
self.mWavelength = self._widget.getWvl()
self.NA = self._widget.getNA()
self.k0 = 2 * np.pi / (self.mWavelength)
self.active = enabled
self.init = False
self.reconstructionMode = self.availableReconstructionModes[2]
self.imageComputationWorker.setReconstructionMode(self.reconstructionMode)
self.imageComputationWorker.setActive(enabled)
self._widget.createPointsLayer()
#detectorName, image, init, scale, detectorName==self._currentDetectorName

def update(self, detectorName, im, init, scale, isCurrentDetector):
""" Update with new detector frame. """
if not isCurrentDetector or not self.active and not isNIP:

if not self.active or not isNIP:# or not isCurrentDetector:
return

if time.time()-self.lastProcessedTime<1/self.updateRate:
return
if self.it == self.updateRate:
self.it = 0
self.imageComputationWorker.prepareForNewImage(im)
self.sigImageReceived.emit()
self.lastProcessedTime = time.time()
else:
self.it += 1

def displayImage(self, im):
def displayImage(self, im, name):
""" Displays the image in the view. """
self._widget.setImage(im)
if im.dtype=="complex":
self._widget.setImage(np.abs(im), name+"_abs")
self._widget.setImage(np.angle(im), name+"_angle")
else:
self._widget.setImage(np.abs(im), name)

def changeRate(self, updateRate):
""" Change update rate. """
if updateRate == "":
return
if updateRate <= 0:
updateRate = 1
self.updateRate = updateRate
self.it = 0

class HoloImageComputationWorker(Worker):
sigHoloImageComputed = Signal(np.ndarray)
sigHoloImageComputed = Signal(np.ndarray, str)

def __init__(self):
super().__init__()

self._logger = initLogger(self, tryInheritParent=False)
self._numQueuedImages = 0
self._numQueuedImagesMutex = Mutex()
self.PSFpara = None
self.pixelsize = 1
self.dz = 1

self.reconstructionMode = "off"
self.active = False
self.CCCenter = None
self.CCRadius = 100
self.isBusy = False

def set_CCCenter(self, CCCenter):
self.CCCenter = CCCenter

def set_CCRadius(self, CCRadius):
self.CCRadius = CCRadius

def setActive(self, active):
self.active = active

def setReconstructionMode(self, mode):
self.reconstructionMode = mode

def reconholo(self, image, PSFpara, N_subroi=1024, pixelsize=1e-3, dz=50e-3):
mimage = nip.image(np.sqrt(image))
mimage = nip.extract(mimage, [N_subroi,N_subroi])
mimage.pixelsize=(pixelsize, pixelsize)
mpupil = nip.ft(mimage)
#nip.__make_propagator__(mpupil, PSFpara, doDampPupil=True, shape=mpupil.shape, distZ=dz)
cos_alpha, sin_alpha = nip.cosSinAlpha(mimage, PSFpara)
defocus = self.dz # defocus factor
PhaseMap = nip.defocusPhase(cos_alpha, defocus, PSFpara)
propagated = nip.ft2d((np.exp(1j * PhaseMap))*mpupil)
return np.squeeze(propagated)

def computeHoloImage(self):
if self.reconstructionMode == "inline":
mimage = nip.image(np.sqrt(image.copy()))
mimage = nip.extract(mimage, [N_subroi,N_subroi])
mimage.pixelsize=(pixelsize, pixelsize)
mpupil = nip.ft(mimage)
#nip.__make_propagator__(mpupil, PSFpara, doDampPupil=True, shape=mpupil.shape, distZ=dz)
cos_alpha, sin_alpha = nip.cosSinAlpha(mimage, PSFpara)
defocus = self.dz # defocus factor
PhaseMap = nip.defocusPhase(cos_alpha, defocus, PSFpara)
propagated = nip.ft2d((np.exp(1j * PhaseMap))*mpupil)
return np.squeeze(propagated)
elif self.reconstructionMode == "offaxis" and self.CCCenter is not None:
mimage = np.sqrt(nip.image(image.copy())) # get e-field
mpupil = nip.ft(mimage) # bring to FT space
mpupil = nip.extract(mpupil, (int(self.CCCenter[0]), int(self.CCCenter[1])), (int(self.CCRadius),int(self.CCRadius)), checkComplex=False) # cut out CC-term
mimage = nip.ift(mpupil) # bring back to image space
return np.squeeze(mimage) # this is still complex
else:
return np.zeros_like(image)

def computeHoloImage(self, mHologram):
""" Compute Holo of an image. """
self.isBusy = True
try:
if self._numQueuedImages > 1:
return # Skip this frame in order to catch up
holorecon = np.flip(np.abs(self.reconholo(self._image, PSFpara=self.PSFpara, N_subroi=1024, pixelsize=self.pixelsize, dz=self.dz)),1)
holorecon = np.flip(self.reconholo(mHologram, PSFpara=self.PSFpara, N_subroi=1024, pixelsize=self.pixelsize, dz=self.dz),1)

self.sigHoloImageComputed.emit(np.array(holorecon), "Hologram")
if self.reconstructionMode == "offaxis":
mFT = nip.ft2d(mHologram)
self.sigHoloImageComputed.emit(np.array(np.log(1+mFT)), "FFT")
except Exception as e:
self._logger.error(f"Error in computeHoloImage: {e}")
self.isBusy = False

self.sigHoloImageComputed.emit(np.array(holorecon))
finally:
self._numQueuedImagesMutex.lock()
self._numQueuedImages -= 1
self._numQueuedImagesMutex.unlock()

def prepareForNewImage(self, image):
""" Must always be called before the worker receives a new image. """
self._image = image
self._numQueuedImagesMutex.lock()
self._numQueuedImages += 1
self._numQueuedImagesMutex.unlock()
if self.active and not self.isBusy:
self.isBusy = True
mThread = threading.Thread(target=self.computeHoloImage, args=(self._image,))
mThread.start()


def set_dz(self, dz):
self.dz = dz
Expand Down
19 changes: 3 additions & 16 deletions imswitch/imcontrol/controller/controllers/MCTController.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,22 +360,9 @@ def performAutofocus(self):
if self.positioner is not None and self._widget.isAutofocus() and np.mod(self.nImagesTaken, int(autofocusParams['valuePeriod'])) == 0:
self._widget.setMessageGUI("Autofocusing...")
# turn on illuimination
if autofocusParams['illuMethod']=="Illu1":
self.lasers[0].setValue(self.Illu1Value)
self.lasers[0].setEnabled(True)
time.sleep(.05)
elif autofocusParams['illuMethod']=="Illu2":
self.lasers[1].setValue(self.Illu2Value)
self.lasers[1].setEnabled(True)
time.sleep(.05)
elif autofocusParams['illuMethod']=="LED":
if len(self.leds)>0:
self.leds[0].setValue(self.Illu3Value)
self.leds[0].setEnabled(True)
time.sleep(.05)
else:
self.illu.setAll(1, (self.Illu3Value,self.Illu3Value,self.Illu3Value))

self.activeIlluminations[0].setValue(autofocusParams["valueRange"])
self.activeIlluminations[0].setEnabled(True)
time.sleep(self.tWait)
self.doAutofocus(autofocusParams)
self.switchOffIllumination()

Expand Down
35 changes: 21 additions & 14 deletions imswitch/imcontrol/controller/controllers/SettingsController.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def adjustFrame(self, *, detector=None):

# Adjust frame
params = self.allParams[detector.name]
binning = int(params.binning.value())
try:
binning = int(params.binning.value())
except:
binning = 1
width = params.width.value()
height = params.height.value()
x0 = params.x0.value()
Expand Down Expand Up @@ -371,20 +374,23 @@ def updateFrame(self, *, detector=None):
self.ROIchanged()

else:
if frameMode == 'Full chip':
fullChipShape = detector.fullShape
params.x0.setValue(0)
params.y0.setValue(0)
params.width.setValue(fullChipShape[0])
params.height.setValue(fullChipShape[1])
if frameMode == "":
pass
else:
roiInfo = self._setupInfo.rois[frameMode]
params.x0.setValue(roiInfo.x)
params.y0.setValue(roiInfo.y)
params.width.setValue(roiInfo.w)
params.height.setValue(roiInfo.h)

self.adjustFrame(detector=detector)
if frameMode == 'Full chip':
fullChipShape = detector.fullShape
params.x0.setValue(0)
params.y0.setValue(0)
params.width.setValue(fullChipShape[0])
params.height.setValue(fullChipShape[1])
else:
roiInfo = self._setupInfo.rois[frameMode]
params.x0.setValue(roiInfo.x)
params.y0.setValue(roiInfo.y)
params.width.setValue(roiInfo.w)
params.height.setValue(roiInfo.h)

self.adjustFrame(detector=detector)

self.syncFrameParams(doAdjustFrame=False)

Expand Down Expand Up @@ -552,3 +558,4 @@ def setDetectorGain(self, detectorName: str=None, gain: float=0) -> None:
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

52 changes: 52 additions & 0 deletions imswitch/imcontrol/model/interfaces/tiscamera_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,55 @@ def __init__(self, isRGB=False, mocktype = None, mockstackpath=None):
self.SensorWidth = dummyFrame.shape[0]
self.properties['SensorHeight'] = self.SensorHeight
self.properties['SensorWidth'] = self.SensorWidth

if self.mocktype == "OffAxisHolo":
# Parameters
width, height = self.SensorWidth, self.SensorHeight # Size of the simulation
wavelength = 0.6328e-6 # Wavelength of the light (in meters, example: 632.8nm for He-Ne laser)
k = 2 * np.pi / wavelength # Wave number
angle = np.pi / 10 # Tilt angle of the plane wave

# Create a phase sample (this is where you define your phase object)
# For demonstration, a simple circular phase object
x = np.linspace(-np.pi, np.pi, width)
y = np.linspace(-np.pi, np.pi, height)
X, Y = np.meshgrid(x, y)
mPupil = (X**2 + Y**2)<np.pi
try:
import NanoImagingPack as nip
mSample = nip.readim()
mSample = nip.extract(mSample, (width, height))/255
except:
mSample = (X**2 + Y**2)<50 # sphere with radius 50
phase_sample = np.exp(1j * mSample)

# Simulate the tilted plane wave
tilt_x = k * np.sin(angle)
tilt_y = k * np.sin(angle) # Change this if you want tilt in another direction
X, Y = np.meshgrid(np.arange(width), np.arange(height))
plane_wave = np.exp(1j * ((tilt_x * X) + (tilt_y * Y)))


# Superpose the phase sample and the tilted plane wave
filtered_phase_sample = np.fft.ifft2(np.fft.fftshift(mPupil) * np.fft.fft2(phase_sample))
#filtered_phase_sample = nip.ift(mPupil * nip.ft(phase_sample))
hologram = filtered_phase_sample + plane_wave
if 0:
import matplotlib.pyplot as plt
plt.imshow(np.angle(hologram))
plt.show()
#plt.imshow(np.angle(np.conjugate(hologram)))
#plt.show()
plt.imshow(np.real(hologram*np.conjugate(hologram)))
plt.show()
plt.imshow(np.log(1+np.abs(nip.ft(hologram))))
plt.show()

#%%
# Calculate the intensity image (interference pattern)
self.holo_intensity_image = np.squeeze(np.real(hologram*np.conjugate(hologram)))




def start_live(self):
Expand Down Expand Up @@ -96,6 +145,9 @@ def grabFrame(self, **kwargs):
# Iterate over the pages in the TIFF file
img = self.tifr.pages[self.iFrame%len(self.tifr.pages)].asarray()
self.iFrame+=1
elif self.mocktype=="OffAxisHolo":
img = self.holo_intensity_image
self.iFrame+=1
elif self.mocktype=="default":
img = np.random.randint(0, 255, (self.SensorHeight, self.SensorWidth)).astype('uint8')
else:
Expand Down
Loading

0 comments on commit 4cb6998

Please sign in to comment.