Debugging after testing at s26.
saugatkandel committed Jul 1, 2022
1 parent 6f97279 commit 0a31792
Showing 3 changed files with 75 additions and 51 deletions.
@@ -8,17 +8,16 @@
#
import fnmatch
import logging
import shutil
from pathlib import Path

import numpy as np
from paramiko import SFTPClient

from sladsnet.code.base import ExperimentalSample
from sladsnet.code.erd import SladsSklearnModel
from sladsnet.code.measurement_interface import ExternalMeasurementInterface
from sladsnet.input_params import ERDInputParams, GeneralInputParams, SampleParams

import numpy as np
from paramiko import SFTPClient


def load_idxs_and_intensities(data_dir: str):
"""Directory containing the npz with inital positions and intensities.
@@ -49,36 +48,44 @@ def clean_data_directory(data_dir: str):
"""This is dangerous. Use with care."""
dpath = Path(data_dir)
if dpath.exists():
logging.warning('Removing the files currently present in %s.' % data_dir)
shutil.rmtree(data_dir)
dpath.mkdir()
logging.warning('Removing the npz files currently present in %s.' % data_dir)

npz_files = dpath.glob('*.npz')
for npz in npz_files:
npz.unlink()
else:
dpath.mkdir()


def get_wildcard_files_remote(sftp: SFTPClient, remote_dir: str, search: str):
matching_filenames = []
remote_dir = str(remote_dir)
logging.info('Getting files from %s'%remote_dir)
for filename in sftp.listdir(remote_dir):
if fnmatch.fnmatch(filename, search):
matching_filenames.append(filename)
matching_filenames.append(str(Path(remote_dir) / filename))
return matching_filenames


def get_init_npzs_from_remote(sftp: SFTPClient, remote_dir: str, data_dir: str):
init_npzs = get_wildcard_files_remote(sftp, remote_dir, 'init*.npz')
for f in init_npzs:
sftp.get(f, data_dir)
local_fname = str(Path(data_dir) / Path(f).name)
logging.info('Copying %s to %s'%(f, local_fname))
sftp.get(f, local_fname)


def create_experiment_sample(numx: int, numy: int,
initial_idxs: list,
inner_batch_size: int = 100,
stop_ratio: float = 0.35,
c_value: float = 2.0,
c_value: int = 2,
full_erd_recalculation_frequency: int = 1,
affected_neighbors_window_min: int = 5,
affected_neighbors_window_max: int = 15,
erd_model_file_path: str = None):
if erd_model_file_path is None:
erd_model_file_path = Path(__file__).parent.parent / 'ResultsAndData/TrainingData/cameraman/' \
erd_model_file_path = Path(__file__).parent.parent.parent / 'ResultsAndData/TrainingData/cameraman/' \
/ f'c_{c_value}/erd_model_relu.pkl'
erd_model = SladsSklearnModel(load_path=erd_model_file_path)

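The fix in get_init_npzs_from_remote reflects how paramiko's SFTPClient.get(remotepath, localpath) works: localpath must be a full file path, not a directory, so the local name is now built from the remote file name. A minimal sketch of the combined listing-and-download pattern, assuming an already connected sftp client (the directory names below are placeholders, not values from this commit):

    import fnmatch
    from pathlib import Path

    def fetch_matching(sftp, remote_dir: str, local_dir: str, pattern: str = 'init*.npz'):
        """Download every file in remote_dir whose name matches pattern."""
        local = Path(local_dir)
        local.mkdir(exist_ok=True)
        for name in sftp.listdir(str(remote_dir)):
            if fnmatch.fnmatch(name, pattern):
                # paramiko needs an explicit local *file* path as the second argument
                sftp.get(str(Path(remote_dir) / name), str(local / name))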
@@ -14,17 +14,19 @@
import tkinter as tk
from pathlib import Path

import epics
import numpy as np
from paramiko import SFTPClient

import helper
from sladsnet.code.base import ExperimentalSample
from sladsnet.code.sampling import check_stopping_criteria
from sladsnet.utils import logger
from sladsnet.utils.readMDA import readMDA
from sladsnet.utils.sftp import setup_sftp, sftp_put, sftp_get

import epics
import numpy as np
from paramiko import SFTPClient

import helper


logger.setup_logging()

REMOTE_PATH = Path('/home/sector26/2022R2/20220621/Analysis/')
@@ -121,8 +123,14 @@ def monitor_function(self, value=None, **kw):

mda = readMDA('mda_current.mda', verbose=False)
data = np.array(mda[1].d[3].data)
xx = np.round((np.array(mda[1].d[32].data) - self.xpos) / self.scan_stepsize, 0) + self.scan_centerx
yy = np.round((np.array(mda[1].d[31].data) - self.ypos) / self.scan_stepsize, 0) + self.scan_centery
# 32 for attoz
xx = np.round((np.array(mda[1].d[35].data) - self.xpos) / self.scan_stepsize / self.xfactor, 0) \
+ self.scan_centerx
# 31 for samy
yy = np.round((np.array(mda[1].d[36].data) - self.ypos) / self.scan_stepsize, 0) + self.scan_centery

#xx = np.round((np.array(mda[1].d[32].data) - self.xpos) / self.scan_stepsize / self.xfactor, 0) + self.scan_centerx
#yy = np.round((np.array(mda[1].d[31].data) - self.ypos) / self.scan_stepsize, 0) + self.scan_centery
curr_pt = mda[1].curr_pt
points_of_interest = curr_pt % self.points_per_route
if points_of_interest == 0:
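The replacement lines above turn absolute motor readbacks into integer indices on the scan grid: subtract the scan origin, divide by the step size (and, for x, by xfactor), round, and add the grid-centre offset. A small numerical sketch with made-up numbers (none of these values come from the beamline configuration):

    import numpy as np

    xpos, scan_stepsize, xfactor, scan_centerx = 10.0, 0.5, 1.0, 64
    readbacks = np.array([10.0, 10.5, 12.5])  # hypothetical motor positions
    xx = np.round((readbacks - xpos) / scan_stepsize / xfactor, 0) + scan_centerx
    # xx -> array([64., 65., 69.]): offsets of 0, 1 and 5 steps from the grid centre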
@@ -133,8 +141,8 @@ def monitor_function(self, value=None, **kw):

if curr_pt != expected_shape:
if self.debug:
logging.info("Warning: At iteration", self.current_file_position,
"data shape is", mda[1].curr_pt, ", but the expected shape is", expected_shape)
logging.warning("At iteration", self.current_file_position,
"data shape is", mda[1].curr_pt, ", but the expected shape is", expected_shape)

# This is for the case when we receive a faulty epics trigger without ives actually having
# the new position indices.
@@ -150,8 +158,8 @@ def monitor_function(self, value=None, **kw):
new_intensities = data[curr_pt - points_of_interest + 1:curr_pt]

if np.shape(route_idxs)[0] != np.shape(new_intensities)[0]:
logging.info('Mismatch between shapes of route %d and ' % xpoints.shape[0] + 'and intensities %d.' % (
np.shape(new_intensities)[0]))
logging.warning('Mismatch between shape of route %d and shape of intensities %d.'
% (np.shape(route_idxs)[0], np.shape(new_intensities)[0]))

if curr_pt == self.store_file_scan_points_num:
if self.completed_run_flag:
@@ -242,27 +250,24 @@ def write_route_file_and_update_suffix(self):

if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('m', help='Current mda file offset.')
argparser.add_argument('-c', default='config.txt',
argparser.add_argument('m', help='Current mda file offset.', type=int)
argparser.add_argument('-c', default='config.ini',
help='Name of config file to download from remote.')
argparser.add_argument('-r', '--stop_ratio', type=float, default=0.35,
argparser.add_argument('-r', '--stop_ratio', type=float, default=0.65,
help='Stop ratio.')
argparser.add_argument('-g', '--indices_to_generate', type=int, default=100,
help='Number of position indices to generate from slads at every measurement step.')
argparser.add_argument('-p', '--points_to_scan', type=int, default=50,
help='Number of points to actual scan at every measurement step.')
argparser.add_argument('-t', '--points_to_store', type=int, default=500,
help='Number of points to store in every instructions file.')
sysargs = argparser.parse_args()

sysargs_str = str(sysargs)
logging.info('Current input arguments are %s'%sysargs_str)

if 2 * sysargs.points_to_scan < sysargs.indices_to_generate:
logging.error('Number of points generated by SLADS should be at least 2x the number of points ' \
'scanned at every step.')
sys.exit()
if sysargs.points_to_store <= sysargs.points_to_scan:
logging.error('Number of points stored in every instructions file should be greater than the number of points' \
'scanned at every step.')
sys.exit()

config_fname = sysargs.c
ssh, sftp = setup_sftp(REMOTE_IP, REMOTE_USERNAME)
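setup_sftp comes from sladsnet.utils.sftp and its body is not part of this diff; presumably it performs the usual paramiko connection setup and hands back the SSH and SFTP clients used below. A hedged sketch of that pattern, assuming key- or agent-based authentication (the function name is hypothetical):

    import paramiko

    def setup_sftp_sketch(hostname: str, username: str):
        """Open an SSH connection and return (ssh, sftp) client objects."""
        ssh = paramiko.SSHClient()
        ssh.load_system_host_keys()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(hostname, username=username)  # assumes an SSH key or agent is available
        return ssh, ssh.open_sftp()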
@@ -274,35 +279,47 @@ def write_route_file_and_update_suffix(self):
cparser.read(config_fname)
cargs = cparser.defaults()

celems = ['scan_sizex', 'scan_sizey', 'xpos', 'ypos', 'xfactor', 'scan_stepsize']
celems = ['scan_sizex', 'scan_sizey', 'xpos', 'ypos', 'xfactor', 'scan_stepsize', 'current_file_suffix',
'scan_npts']
if set(celems).isdisjoint(cargs):
logging.error('Config file does not contain required input parameters.')

sample = helper.create_experiment_sample(numx=celems['scan_sizex'], numy=celems['scan_sizey'],
inner_batch_size=sysargs.indices_to_generate,
stop_ratio=sysargs.stop_ratio,
c_value=2.0,
full_erd_recalculation_frequency=1,
affected_neighbors_window_min=5,
affected_neighbors_window_max=15)
if int(cargs['scan_npts']) <= sysargs.points_to_scan:
logging.error('Number of points stored in every instructions file should be greater than the number of points ' \
'scanned at every step.')
sys.exit()

cargs_str = str(cargs)
logging.info('Config parameters are %s'%cargs_str)

init_data_dir = 'initial_data'
helper.clean_data_directory(init_data_dir)
helper.get_init_npzs_from_remote(sftp=sftp, remote_dir=REMOTE_PATH, data_dir=init_data_dir)
n_init, initial_idxs, initial_intensities = helper.load_idxs_and_intensities(init_data_dir)

logging.info('Downloaded %d init files from remote %s.'%(n_init, REMOTE_PATH))

sample = helper.create_experiment_sample(numx=int(cargs['scan_sizex']), numy=int(cargs['scan_sizey']),
inner_batch_size=sysargs.indices_to_generate,
initial_idxs=initial_idxs,
stop_ratio=sysargs.stop_ratio,
c_value=2,
full_erd_recalculation_frequency=1,
affected_neighbors_window_min=5,
affected_neighbors_window_max=15)

root = tk.Tk()
gui = MainWindow(master=root, sample=sample,
scan_sizex=celems['scan_sizex'],
scan_sizey=celems['scan_sizey'],
xpos=celems['xpos'],
ypos=celems['ypos'],
xfactor=celems['xfactor'],
scan_stepsize=celems['scan_stepsize'],
current_file_suffix=n_init,
scan_sizex=int(cargs['scan_sizex']),
scan_sizey=int(cargs['scan_sizey']),
xpos=float(cargs['xpos']),
ypos=float(cargs['ypos']),
xfactor=float(cargs['xfactor']),
scan_stepsize=float(cargs['scan_stepsize']),
current_file_suffix=int(cargs['current_file_suffix']),
mda_file_offset=sysargs.m,
sftp=sftp,
store_file_scan_points_num=sysargs.t,
store_file_scan_points_num=int(cargs['scan_npts']),
points_per_route=sysargs.points_to_scan)

gui.update(initial_idxs, initial_intensities)
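The main block now takes the scan geometry from the config file fetched over SFTP, reading it with configparser and casting each value to int or float at the point of use. One detail worth noting: set(celems).isdisjoint(cargs) only raises an error when none of the required keys are present. A stricter check, sketched here under the assumption that all eight keys are mandatory, would also catch a partially filled config file:

    import configparser
    import logging
    import sys

    REQUIRED = ['scan_sizex', 'scan_sizey', 'xpos', 'ypos', 'xfactor',
                'scan_stepsize', 'current_file_suffix', 'scan_npts']

    cparser = configparser.ConfigParser()
    cparser.read('config.ini')  # file name mirrors the -c default above
    cargs = cparser.defaults()

    missing = set(REQUIRED) - set(cargs)
    if missing:
        logging.error('Config file is missing required parameters: %s', sorted(missing))
        sys.exit(1)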
8 changes: 4 additions & 4 deletions sladsnet/utils/logger.py
@@ -4,9 +4,9 @@


def setup_logging(out_dir: str = 'LOGS',
out_prefix:str ='smart_scan',
level='INFO'):
formatter = logging.Formatter("%(asctime)s;%(message)s",datefmt="%Y-%m-%d %H:%M:%S")
out_prefix: str = 'smart_scan',
level='INFO'):
formatter = logging.Formatter("%(asctime)s; %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
ch = logging.StreamHandler()
ch.setFormatter(formatter)

@@ -15,4 +15,4 @@ def setup_logging(out_dir: str = 'LOGS',
dt = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
fh = logging.FileHandler(out_dir / f'{out_prefix}_{dt}.log')
fh.setFormatter(formatter)
logging.basicConfig(handlers=[ch, fh], level=level)#formatter=formatter, level=level)
logging.basicConfig(handlers=[ch, fh], level=level)

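With the reformatted formatter string, setup_logging sends identical records to the console and to a timestamped file under out_dir (assuming the lines omitted from this hunk create that directory). A minimal usage sketch; the log message and timestamp are illustrative:

    import logging
    from sladsnet.utils import logger

    logger.setup_logging()  # defaults: out_dir='LOGS', out_prefix='smart_scan', level='INFO'
    logging.info('Starting smart scan')
    # both handlers emit lines like:
    # 2022-07-01 12:00:00; Starting smart scan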