Skip to content
Snippets Groups Projects
Commit 95e8d0e9 authored by Yurov, Dmitry's avatar Yurov, Dmitry
Browse files

DefaultFitObserver now works for specular fitting

Redmine: #1859

DefaultFitObserver was left almost empty, the main functionality
was moved to Plotter classes (PlotterGISAS and PlotterSpecular)
parent de0109a9
No related branches found
No related tags found
No related merge requests found
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# @authors J. Fisher, M. Ganeva, G. Pospelov, W. Van Herck, J. Wuttke # @authors J. Fisher, M. Ganeva, G. Pospelov, W. Van Herck, J. Wuttke
''' '''
# ************************************************************************** # # ************************************************************************** #
from __future__ import print_function from __future__ import print_function
import bornagain as ba import bornagain as ba
from bornagain import deg as deg from bornagain import deg as deg
from bornagain import IFitObserver as IFitObserver from bornagain import IFitObserver as IFitObserver
from matplotlib import pyplot as plt
from matplotlib import gridspec, colors
def get_axes_limits(intensity): def get_axes_limits(intensity):
...@@ -31,7 +32,7 @@ def get_axes_limits(intensity): ...@@ -31,7 +32,7 @@ def get_axes_limits(intensity):
# We show radians as degrees. If no units defined in histogram object, # We show radians as degrees. If no units defined in histogram object,
# we assume radians. # we assume radians.
if "rad" in intensity.axesUnits() or len(intensity.axesUnits()) == 0: if "rad" in intensity.axesUnits() or len(intensity.axesUnits()) == 0:
result = [x/deg for x in result] result = [x / deg for x in result]
return result return result
...@@ -77,10 +78,6 @@ def plot_colormap(intensity, zmin=None, zmax=None, ...@@ -77,10 +78,6 @@ def plot_colormap(intensity, zmin=None, zmax=None,
:param zmin: Min value on amplitude's color bar :param zmin: Min value on amplitude's color bar
:param zmax: Max value on amplitude's color bar :param zmax: Max value on amplitude's color bar
""" """
import matplotlib
from matplotlib import pyplot as plt
zmin = 1.0 if zmin is None else zmin zmin = 1.0 if zmin is None else zmin
zmax = intensity.getMaximum() if zmax is None else zmax zmax = intensity.getMaximum() if zmax is None else zmax
...@@ -90,7 +87,7 @@ def plot_colormap(intensity, zmin=None, zmax=None, ...@@ -90,7 +87,7 @@ def plot_colormap(intensity, zmin=None, zmax=None,
im = plt.imshow( im = plt.imshow(
intensity.getArray(), intensity.getArray(),
norm=matplotlib.colors.LogNorm(zmin, zmax), norm=colors.LogNorm(zmin, zmax),
extent=get_axes_limits(intensity), extent=get_axes_limits(intensity),
aspect='auto', aspect='auto',
) )
...@@ -117,36 +114,34 @@ def plot_intensity_data(intensity, zmin=None, zmax=None): ...@@ -117,36 +114,34 @@ def plot_intensity_data(intensity, zmin=None, zmax=None):
:param zmin: Min value on amplitude's color bar :param zmin: Min value on amplitude's color bar
:param zmax: Max value on amplitude's color bar :param zmax: Max value on amplitude's color bar
""" """
from matplotlib import pyplot as plt
plot_colormap(intensity, zmin, zmax) plot_colormap(intensity, zmin, zmax)
plt.show()
class DefaultFitObserver(IFitObserver): class Plotter:
""" def __init__(self):
Draws fit progress every nth iteration. This class has to be attached to
FitSuite via attachObserver method.
FitSuite kernel will call DrawObserver's update() method every n'th iteration.
It is up to the user what to do here.
"""
def __init__(self, draw_every_nth=10): self._fig = plt.figure(figsize=(10.25, 7.69))
IFitObserver.__init__(self, draw_every_nth) self._fig.canvas.draw()
import matplotlib def reset(self):
from matplotlib import pyplot as plt self._fig.clf()
global matplotlib, plt
self.fig = plt.figure(figsize=(10.25, 7.69)) def plot(self, fit_suite):
self.fig.canvas.draw() plt.pause(0.03)
plt.ion()
def make_subplot(self, nplot):
class PlotterGISAS(Plotter):
def __init__(self):
Plotter.__init__(self)
@staticmethod
def make_subplot(nplot):
plt.subplot(2, 2, nplot) plt.subplot(2, 2, nplot)
plt.subplots_adjust(wspace=0.2, hspace=0.2) plt.subplots_adjust(wspace=0.2, hspace=0.2)
def update(self, fit_suite): def plot(self, fit_suite):
self.fig.clf() Plotter.reset(self)
self.make_subplot(1) self.make_subplot(1)
real_data = fit_suite.getRealData() real_data = fit_suite.getRealData()
...@@ -173,10 +168,148 @@ class DefaultFitObserver(IFitObserver): ...@@ -173,10 +168,148 @@ class DefaultFitObserver(IFitObserver):
format(fit_suite.numberOfIterations(), fit_suite.minimizer().minimizerName())) format(fit_suite.numberOfIterations(), fit_suite.minimizer().minimizerName()))
plt.text(0.01, 0.75, "Chi2 " + '{:8.4f}'.format(fit_suite.getChi2())) plt.text(0.01, 0.75, "Chi2 " + '{:8.4f}'.format(fit_suite.getChi2()))
for index, fitPar in enumerate(fit_suite.fitParameters()): for index, fitPar in enumerate(fit_suite.fitParameters()):
plt.text(0.01, 0.55 - index*0.1, '{:30.30s}: {:6.3f}'.format(fitPar.name(), fitPar.value())) plt.text(0.01, 0.55 - index * 0.1, '{:30.30s}: {:6.3f}'.format(fitPar.name(), fitPar.value()))
Plotter.plot(self, fit_suite)
plt.draw() class PlotterSpecular(Plotter):
plt.pause(0.01) """
Draws fit progress every nth iteration. This class has to be attached to
FitSuite via attachObserver method. Intended specifically for observing
specular data fit.
FitSuite kernel will call DrawObserver's update() method every n'th iteration.
"""
def __init__(self, draw_every_nth=10):
Plotter.__init__(self)
self.gs = gridspec.GridSpec(1, 2, width_ratios=[2.5, 1], wspace=0)
@staticmethod
def as_si(val, ndp):
"""
Fancy print of scientific-formatted values
:param val: numeric value
:param ndp: number of decimal digits to print
:return: a string corresponding to the _val_
"""
s = '{x:0.{ndp:d}e}'.format(x=val, ndp=ndp)
m, e = s.split('e')
return r'{m:s}\times 10^{{{e:d}}}'.format(m=m, e=int(e))
@staticmethod
def trunc_str(token, length):
"""
Truncates token if it is longer than length.
Example:
trunc_str("123456789", 8) returns "123456.."
trunc_str("123456789", 9) returns "123456789"
:param token: input string
:param length: max non-truncated length
:return:
"""
return (token[:length - 2] + '..') if len(token) > length else token
def plot_table(self, fit_suite):
# definitions and values
trunc_length = 9 # max string field width in the table
n_digits = 1 # number of decimal digits to print
n_iterations = fit_suite.numberOfIterations() # number of iterations
minimizer = fit_suite.minimizer().minimizerName()
fom_max = fit_suite.getChiSquaredMap().getArray().max() # max Figure Of Merit (FOM) value
fitted_parameters = fit_suite.fitParameters()
# creating table content
labels = ("Parameter", "Value")
table_data = [["Minimizer", '{:s}'.format(self.trunc_str(minimizer, trunc_length))],
["Iteration", '${:d}$'.format(n_iterations)],
["$\chi^2$", '${:s}$'.format(self.as_si(fom_max, n_digits))]]
for fitPar in fitted_parameters:
table_data.append(['{:s}'.format(self.trunc_str(fitPar.name(), trunc_length)),
'${:s}$'.format(self.as_si(fitPar.value(), n_digits))])
# creating table
axs = plt.subplot(self.gs[1])
axs.axis('tight')
axs.axis('off')
table = plt.table(cellText=table_data, colLabels=labels, cellLoc='center',
loc='bottom left', bbox=[0.0, 0.0, 1.0, 1.0])
# # setting alignment to center in 0th column
# for key, cell in table.get_celld().items():
# col = key[1]
# if col == 0:
# cell._loc = 'center' # accessing private member is the only option to change alignment for one cell
def plot_graph(self, fit_suite):
# retrieving data from fit suite
real_data = fit_suite.getRealData()
sim_data = fit_suite.getSimulationData()
# normalizing axis coordinates
axis = real_data.getXaxis().getBinCenters()
norm = 1
if "rad" in real_data.axesUnits():
norm = ba.deg
axis_values = [value / norm for value in axis]
# default font properties dictionary to use
font = {'family': 'serif',
'weight': 'normal',
'size': 16}
plt.subplot(self.gs[0])
plt.semilogy(axis_values, sim_data.getArray(), 'b',
axis_values, real_data.getArray(), 'k--')
plt.ylim((0.5 * real_data.getMinimum(), 5 * real_data.getMaximum()))
plt.legend(['BornAgain', 'Reference'], loc='upper right', prop=font)
plt.xlabel("Incident angle, deg", fontdict=font)
plt.ylabel("Intensity", fontdict=font)
plt.title("Specular data fitting", fontdict=font)
def plot(self, fit_suite):
Plotter.reset(self)
self.plot_graph(fit_suite)
self.plot_table(fit_suite)
plt.tight_layout()
Plotter.plot(self, fit_suite)
class DefaultFitObserver(IFitObserver):
"""
Draws fit progress every nth iteration. This class has to be attached to
FitSuite via attachObserver method.
FitSuite kernel will call DrawObserver's update() method every n'th iteration.
It is up to the user what to do here.
"""
def __init__(self, draw_every_nth=10, SimulationType='GISAS'):
"""
Initializes observer
:param draw_every_nth: specifies when to output data, defaults to each 10th iteration
:param SimulationType: simulation type underlying fitting:
'GISAS' - GISAS simulation, default
'Specular' - specular simulation
"""
IFitObserver.__init__(self, draw_every_nth)
if SimulationType is 'GISAS':
self._plotter = PlotterGISAS()
elif SimulationType is 'Specular':
self._plotter = PlotterSpecular()
else:
exit("Unknown simulation type {:s}.".format(SimulationType))
def update(self, fit_suite):
try:
self._plotter.plot(fit_suite)
except Exception, e:
print(e.message)
if fit_suite.isLastIteration():
plt.ioff()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment