Source code for ncempy.algo.distortion

Module to handle distortions in diffraction patterns.

import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

def filter_ring(points, center, rminmax):
    '''Filter points to be in a certain radial distance range from center.

    Parameters:
        points (np.ndarray): Candidate points, two column array.
        center (np.ndarray/tuple): Center position.
        rminmax (tuple): Tuple of min and max radial distance (both inclusive).

    Returns:
        (np.ndarray): List of filtered points, two column array, or None if
        no point lies within the given range.

    Raises:
        TypeError: If points is not a two column np.ndarray, or center /
            rminmax cannot be reshaped to two values.
    '''
    # explicit checks instead of assert, which is stripped when running with -O
    if not isinstance(points, np.ndarray) or points.ndim != 2 or points.shape[1] != 2:
        raise TypeError('Something wrong with the input!')
    try:
        # center and rminmax can either be tuple or np.array
        center = np.reshape(np.array(center), 2)
        rminmax = np.reshape(np.array(rminmax), 2)
    except (TypeError, ValueError) as err:
        raise TypeError('Something wrong with the input!') from err

    # calculate radii
    rs = np.sqrt(np.square(points[:, 0] - center[0]) + np.square(points[:, 1] - center[1]))

    # filter by given limits
    sel = (rs >= rminmax[0]) & (rs <= rminmax[1])
    if not sel.any():
        return None
    return points[sel]
def points_topolar(points, center):
    '''Convert points to polar coordinate system.

    Can be either in pixel or real dim, but should be the same for points and center.

    Parameters:
        points (np.ndarray): Positions as two column array.
        center (np.ndarray/tuple): Origin of the polar coordinate system.

    Returns:
        (np.ndarray): Positions in polar coordinate system as two column array (r, theta),
        with theta in radians in the range [-pi, pi] as returned by np.arctan2.

    Raises:
        TypeError: If points is not a two column np.ndarray or center cannot
            be reshaped to two values.
    '''
    # explicit checks instead of assert, which is stripped when running with -O
    if not isinstance(points, np.ndarray) or points.ndim != 2 or points.shape[1] != 2:
        raise TypeError('Something wrong with the input!')
    try:
        # center can either be tuple or np.array
        center = np.reshape(np.array(center), 2)
    except (TypeError, ValueError) as err:
        raise TypeError('Something wrong with the input!') from err

    dx = points[:, 0] - center[0]
    dy = points[:, 1] - center[1]

    # calculate radii and angles
    rs = np.sqrt(np.square(dx) + np.square(dy))
    thes = np.arctan2(dy, dx)

    return np.array([rs, thes]).transpose()
def plot_ringpolar(points, dims, show=False):
    '''Plot points in polar coordinate system.

    Parameters:
        points (np.ndarray): Positions in polar coords (r, theta); a single
            position may be given as a flat pair.
        dims (tuple): Dimension information to plot labels; dims[0][2] is
            used as the unit of the radial axis.
        show (bool): Set to directly show plot in interactive mode.

    Returns:
        (np.ndarray): RGB image of the plot, shape (height, width, 3), dtype uint8.

    Raises:
        TypeError: If points cannot be arranged as two columns or dims lacks
            the required entries.
    '''
    try:
        # try to convert input to np.ndarray with 2 columns (necessary if only one entry provided)
        points = np.reshape(np.array(points), (-1, 2))
        # check if enough dims available
        dims_ok = len(dims) >= 2 and len(dims[0]) == 3
    except (TypeError, ValueError, IndexError, KeyError) as err:
        raise TypeError('Something wrong with the input!') from err
    if not dims_ok:
        raise TypeError('Something wrong with the input!')

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # mean value as line
    ax.axhline(np.mean(points[:, 0]), ls='--', c='k')

    # points
    ax.plot(points[:, 1], points[:, 0], 'rx')

    # labels
    ax.set_xlabel('theta /[rad]')
    ax.set_xlim((-np.pi, np.pi))
    ax.set_ylabel('r /{}'.format(dims[0][2]))

    if show:
        plt.show(block=False)

    # Render to array. This previously only happened when show was set,
    # leaving `plot` unbound (UnboundLocalError) for the default show=False.
    # buffer_rgba replaces the deprecated tostring_rgb / np.fromstring pair;
    # ascontiguousarray copies, so the image outlives the figure.
    fig.canvas.draw()
    plot = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])

    if not show:
        # the figure is only needed for the returned image; avoid leaking it
        plt.close(fig)

    return plot
def residuals_center(param, data):
    '''Residuals of the radial distances around a candidate center.

    Used as the objective for minimizing the deviations from the mean radial distance.

    Parameters:
        param (np.ndarray): The center (x, y) to optimize.
        data (np.ndarray): The points in x,y coordinates of the original image, two column array.

    Returns:
        (np.ndarray): Radial distance of each point minus the mean radial distance.
    '''
    # only the radii are needed here, so skip a full polar conversion
    offset_x = data[:, 0] - param[0]
    offset_y = data[:, 1] - param[1]
    radii = np.sqrt(np.square(offset_x) + np.square(offset_y))
    mean_radius = np.mean(radii)
    return radii - mean_radius
def optimize_center(points, center, maxfev=1000, verbose=None):
    '''Optimize the center by minimizing the sum of square deviations from the mean radial distance.

    Parameters:
        points (np.ndarray): The points to which the optimization is done (x,y coords in org image),
            two column array.
        center (np.ndarray/tuple): Initial center guess.
        maxfev (int): Max number of iterations forwarded to scipy.optimize.leastsq().
        verbose (bool): Set to get verbose output.

    Returns:
        (np.ndarray): The optimized center.

    Raises:
        TypeError: If points is not a two column np.ndarray or center cannot
            be reshaped to two values.
    '''
    # explicit checks instead of assert, which is stripped when running with -O
    if not isinstance(points, np.ndarray) or points.ndim != 2 or points.shape[1] != 2:
        raise TypeError('Something wrong with the input!')
    try:
        # center can either be tuple or np.array
        center = np.reshape(np.array(center), 2)
    except (TypeError, ValueError) as err:
        raise TypeError('Something wrong with the input!') from err

    # run the optimization; args must be a tuple, `(points)` is just `points`
    popt, flag = scipy.optimize.leastsq(residuals_center, center, args=(points,), maxfev=maxfev)

    # leastsq reports a found solution with ier in 1..4
    if flag not in (1, 2, 3, 4):
        print('WARNING: center optimization failed.')

    if verbose:
        # report the optimized center, not the initial guess
        print('optimized center: ({}, {})'.format(popt[0], popt[1]))

    return popt
def rad_dis(theta, alpha, beta, order=2):
    '''Radial distortion due to ellipticity or higher order distortion.

    Relative distortion, to be multiplied with radial distance.

    Parameters:
        theta (np.ndarray/float): Angles at which to evaluate.
        alpha (float): Orientation of major axis.
        beta (float): Strength of distortion, beta = (1-r_min/r_max)/(1+r_min/r_max).
        order (int): Order of distortion.

    Returns:
        (np.ndarray/float): Distortion factor
        (1 - beta^2) / sqrt(1 + beta^2 - 2*beta*cos(order*(theta + alpha))).
    '''
    beta_sq = np.square(beta)
    phase = order * (theta + alpha)
    denominator = np.sqrt(1. + beta_sq - 2. * beta * np.cos(phase))
    return (1. - beta_sq) / denominator
def residuals_dis(param, points, ns):
    '''Residual function for distortions.

    The model radius is param[0] multiplied by one rad_dis() factor per entry
    of ns; the factor for order ns[k] uses param[2*k+1] as alpha and
    param[2*k+2] as beta.

    Parameters:
        param (np.ndarray): Parameters for distortion, [R, alpha_0, beta_0, alpha_1, beta_1, ...].
        points (np.ndarray): Points to fit to, polar coords with radius in
            column 0 and angle in column 1.
        ns (tuple): List of orders to account for.

    Returns:
        (np.ndarray): Residuals (measured radius minus modelled radius).
    '''
    radii = points[:, 0]
    thetas = points[:, 1]

    model = param[0] * np.ones(thetas.shape)
    for k, order in enumerate(ns):
        alpha = param[2 * k + 1]
        beta = param[2 * k + 2]
        model *= rad_dis(thetas, alpha, beta, order)

    return radii - model
def optimize_distortion(points, ns, maxfev=1000, verbose=False):
    '''Optimize distortions.

    The orders in the list ns are first fitted one after another (each on
    the points corrected for the previously fitted orders), and the result is
    refined in a final fit simultaneously fitting all orders.

    Parameters:
        points (np.ndarray): Points to optimize to (in polar coords), two column array (r, theta).
        ns (tuple): List of orders to correct for.
        maxfev (int): Max number of iterations forwarded to scipy.optimize.leastsq().
        verbose (bool): Set for verbose output.

    Returns:
        (np.ndarray): Optimized parameters according to ns, laid out as
        [R, alpha_0, beta_0, alpha_1, beta_1, ...].

    Raises:
        TypeError: If the input is malformed or there are too few points to fit.
    '''
    try:
        n_orders = len(ns)
    except TypeError as err:
        raise TypeError('Something wrong with the input!') from err
    n_params = 2 * n_orders + 1

    # explicit checks instead of assert, which is stripped when running with -O;
    # leastsq needs at least as many residuals as parameters: 3 for each
    # single-order fit and n_params for the full fit
    if (not isinstance(points, np.ndarray) or points.ndim != 2 or points.shape[1] != 2
            or n_orders < 1 or points.shape[0] < max(3, n_params)):
        raise TypeError('Something wrong with the input!')

    # init guess for full fit
    init_guess = np.ones(n_params)
    init_guess[0] = np.mean(points[:, 0])

    # working copy as float, so the in-place correction below also works for integer input
    points_tmp = np.array(points, dtype=float)

    if verbose:
        print('correction for {} order distortions.'.format(ns))
        print('starting with subsequent fitting:')

    # subsequently fit the orders
    for i, order in enumerate(ns):
        # optimize order to points_tmp
        popt, flag = scipy.optimize.leastsq(residuals_dis, (init_guess[0], 0.1, 0.1),
                                            args=(points_tmp, (order,)), maxfev=maxfev)
        if flag not in (1, 2, 3, 4):
            print('WARNING: optimization of distortions failed.')

        if verbose:
            print('fitted order {}: R={} alpha={} beta={}'.format(order, popt[0], popt[1], popt[2]))

        # save alpha and beta as start values for the full fit
        init_guess[i * 2 + 1] = popt[1]
        init_guess[i * 2 + 2] = popt[2]

        # remove this order before fitting the next one
        points_tmp[:, 0] /= rad_dis(points_tmp[:, 1], popt[1], popt[2], order)

    # full fit on the original points
    if verbose:
        print('starting the full fit:')

    popt, flag = scipy.optimize.leastsq(residuals_dis, init_guess, args=(points, ns), maxfev=maxfev)
    if flag not in (1, 2, 3, 4):
        print('WARNING: optimization of distortions failed.')

    if verbose:
        print('fitted to: R={}'.format(popt[0]))
        for i, order in enumerate(ns):
            print('.. order={}, alpha={}, beta={}'.format(order, popt[i * 2 + 1], popt[i * 2 + 2]))

    return popt
def plot_distpolar(points, dims, dists, ns, show=False):
    '''Plot the results of distortion fitting in polar coordinates.

    Draws the fitted mean radius (dashed black), each single-order distortion
    (dashed magenta), the input points (red x), the combined distortion
    (blue line) and the distortion-corrected points (green x).

    Parameters:
        points (np.ndarray): Points in polar coords, two column array (r, theta).
        dims (tuple): Dimensions, necessary to have unit information; dims[0][2]
            is used as the unit of the radial axis.
        dists (np.ndarray): Results of dist fitting, length according to ns:
            [R, alpha_0, beta_0, alpha_1, beta_1, ...].
        ns (list): List of used orders.
        show (bool): Set to directly show the plot in interactive mode.

    Returns:
        (np.ndarray): RGB image of the plot, shape (height, width, 3), dtype uint8.

    Raises:
        TypeError: If any input is malformed or dists does not match ns.
    '''
    try:
        n_orders = len(ns)
        dists = np.asarray(dists)
        inputs_ok = (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 2
                     # check if enough dims available
                     and len(dims) >= 2 and len(dims[0]) == 3
                     # check orders and matching number of fit results
                     and n_orders >= 1 and dists.shape[0] == n_orders * 2 + 1)
    except (TypeError, ValueError, IndexError, KeyError) as err:
        raise TypeError('Something wrong with the input!') from err
    if not inputs_ok:
        raise TypeError('Something wrong with the input!')

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # stuff from the single orders
    ax.axhline(dists[0], ls='--', c='k')
    xpl_ell = np.linspace(-np.pi, np.pi, 100)
    for i, order in enumerate(ns):
        ax.plot(xpl_ell, dists[0] * rad_dis(xpl_ell, dists[i * 2 + 1], dists[i * 2 + 2], order), 'm--')

    # points before
    ax.plot(points[:, 1], points[:, 0], 'rx')

    # combination of all distortions
    sum_dists = np.ones(xpl_ell.shape) * dists[0]
    for i, order in enumerate(ns):
        sum_dists *= rad_dis(xpl_ell, dists[i * 2 + 1], dists[i * 2 + 2], order)
    ax.plot(xpl_ell, sum_dists, 'b-')

    # points after; float copy so the in-place division works for integer input
    points_corr = np.array(points, dtype=float)
    for i, order in enumerate(ns):
        points_corr[:, 0] /= rad_dis(points[:, 1], dists[i * 2 + 1], dists[i * 2 + 2], order)
    ax.plot(points_corr[:, 1], points_corr[:, 0], 'gx')

    # labels
    ax.set_xlabel('theta /[rad]')
    ax.set_xlim((-np.pi, np.pi))
    ax.set_ylabel('r /{}'.format(dims[0][2]))

    if show:
        plt.show(block=False)

    # Render to array. This previously only happened when show was set,
    # leaving `plot` unbound (UnboundLocalError) for the default show=False.
    # buffer_rgba replaces the deprecated tostring_rgb / np.fromstring pair;
    # ascontiguousarray copies, so the image outlives the figure.
    fig.canvas.draw()
    plot = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])

    if not show:
        # the figure is only needed for the returned image; avoid leaking it
        plt.close(fig)

    return plot