Source code for sherpa.models.template

#
#  Copyright (C) 2011, 2016, 2019, 2020, 2021  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

import operator

import numpy

from sherpa.utils.err import ModelErr
from .parameter import Parameter
from .model import ArithmeticModel, modelCacher1d
from .basic import TableModel

# Names exported by a star-import of this module.
__all__ = ('create_template_model', 'TemplateModel', 'KNNInterpolator',
           'Template')


def create_template_model(modelname, names, parvals, templates,
                          template_interpolator_name='default'):
    """Create a template model from a parameter grid and its templates.

    Parameters
    ----------
    modelname : str
        Name of the template model.
    names : sequence of str
        The parameter names, in the order of the columns of `parvals`.
    parvals : 2-D ndarray
        The parameter grid. Row ``i`` holds the parameter vector that
        corresponds to ``templates[i]``.
    templates : list of TableModel
        The spectrum at each parameter vector (one per row of
        `parvals`).
    template_interpolator_name : str or None, optional
        Name of the template interpolator (a key of the module-level
        ``interpolators`` table), or None to disable interpolation
        *between* templates. See load_template_model for more
        information.

    Returns
    -------
    model
        The TemplateModel itself when `template_interpolator_name` is
        None, otherwise an instance of the interpolator class
        registered under that name, wrapping the TemplateModel.

    Raises
    ------
    ModelErr
        If `template_interpolator_name` is not None and is not a
        registered interpolator name.

    """
    # Create one parameter per column of the grid. Both the soft and
    # the hard limits are set to the range covered by the grid.
    pars = []
    for ii, name in enumerate(names):
        column = parvals[:, ii]
        minimum = min(column)
        maximum = max(column)
        # Initial parameter value is always first parameter value listed
        initial = column[0]
        par = Parameter(modelname, name, initial,
                        minimum, maximum,
                        minimum, maximum)
        pars.append(par)

    # Create the templates table from input
    tm = TemplateModel(modelname, pars, parvals, templates)
    if template_interpolator_name is None:
        return tm

    if template_interpolator_name not in interpolators:
        # Fail loudly rather than silently handing None to the caller.
        raise ModelErr("Unknown template_interpolator_name=%s" %
                       (template_interpolator_name,))

    interp = interpolators[template_interpolator_name]

    # Work on a copy so that the registry's shared keyword dictionary
    # is not modified (and does not keep a reference to this model).
    args = dict(interp[1])
    args['template_model'] = tm
    args['name'] = modelname
    return interp[0](**args)


class InterpolatingTemplateModel(ArithmeticModel):
    """Evaluate a template model by interpolating between its templates.

    The parameters of the wrapped template model are re-used as the
    parameters of this model. The class calls ``self.interpolate(p, x0)``
    but does not define it, so a subclass has to supply that method; its
    return value is called with the evaluation grid.
    """

    def __init__(self, name, template_model):
        self.template_model = template_model

        # Store the parameters directly in the instance dictionary,
        # keyed by parameter name.
        wrapped_pars = template_model.pars
        for param in wrapped_pars:
            self.__dict__[param.name] = param

        self.parvals = template_model.parvals
        ArithmeticModel.__init__(self, name, wrapped_pars)

    def fold(self, data):
        """Call ``fold(data)`` on every template of the wrapped model."""
        for tmpl in self.template_model.templates:
            tmpl.fold(data)

    @modelCacher1d
    def calc(self, p, x0, x1=None, *args, **kwargs):
        """Interpolate a template at `p` and evaluate it on the grid."""
        return self.interpolate(p, x0)(x0, x1, *args, **kwargs)
class KNNInterpolator(InterpolatingTemplateModel):
    """Interpolate using the k templates nearest to the requested point.

    Parameters
    ----------
    name : str
        Name of the model.
    template_model
        The template model providing ``parvals`` and ``templates``.
    k : int or None, optional
        Number of nearest templates to combine. If None, twice the
        size of the first row of ``template_model.parvals`` is used.
    order : optional
        Passed to ``numpy.linalg.norm`` as the norm order when
        measuring the distance to each template.
    """

    def __init__(self, name, template_model, k=None, order=2):
        self._distances = {}
        self.k = 2 * template_model.parvals[0].size if k is None else k
        self.order = order
        InterpolatingTemplateModel.__init__(self, name, template_model)

    def _calc_distances(self, point):
        """Store ``(index, distance)`` pairs sorted by distance.

        The result is kept in ``self._distances``; the index refers to
        the row of ``template_model.parvals``.
        """
        by_row = {}
        self._distances = by_row
        for row, grid_point in enumerate(self.template_model.parvals):
            by_row[row] = numpy.linalg.norm(point - grid_point, self.order)

        self._distances = sorted(by_row.items(), key=lambda pair: pair[1])

    def interpolate(self, point, x_out):
        """Return a model for the templates interpolated at `point`.

        If the nearest template is at zero distance it is returned
        unchanged. Otherwise the k nearest templates are evaluated on
        `x_out`, combined with inverse-distance weights, and loaded
        into a new TableModel.
        """
        self._calc_distances(point)
        templates = self.template_model.templates

        nearest_row, nearest_dist = self._distances[0]
        if nearest_dist == 0:
            return templates[nearest_row]

        neighbours = self._distances[:self.k]

        # Normalisation: the sum of the inverse distances.
        sum_weights = sum(1 / dist for _, dist in neighbours)

        y_out = numpy.zeros(len(x_out))
        for row, dist in neighbours:
            # The weight is handed to the template as its only
            # parameter value.
            weight = 1 / numpy.array(dist)
            y_out += templates[row].calc((weight,), x_out)

        y_out /= sum_weights

        result = TableModel('interpolated')
        result.load(x_out, y_out)
        return result
class Template(KNNInterpolator):
    """A KNNInterpolator under another name.

    The constructor forwards all of its arguments to KNNInterpolator
    unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
class TemplateModel(ArithmeticModel):
    """A discrete library of templates indexed by parameter values.

    Parameters
    ----------
    name : str, optional
        Name of the model.
    pars : sequence of Parameter, optional
        The model parameters.
    parvals : sequence or None, optional
        The parameter grid; entry ``i`` is the parameter vector of
        ``templates[i]``. None is treated as an empty grid.
    templates : sequence or None, optional
        The template for each entry of `parvals`. None is treated as
        an empty list.
    """

    def __init__(self, name='templatemodel', pars=(), parvals=None,
                 templates=None):
        # Normalise the optional arguments once and use the normalised
        # values below, so that the None defaults do not fail when the
        # lookup table is built.
        parvals = parvals if parvals is not None else []
        templates = templates if templates is not None else []

        self.parvals = parvals
        self.templates = templates
        self.index = {}

        # Store the parameters directly in the instance dictionary,
        # keyed by parameter name.
        for par in pars:
            self.__dict__[par.name] = par

        # Map each parameter vector (as a tuple) to its template.
        for ii, parval in enumerate(parvals):
            self.index[tuple(parval)] = templates[ii]

        ArithmeticModel.__init__(self, name, pars)
        self.is_discrete = True

    def fold(self, data):
        """Call ``fold(data)`` on every template."""
        for template in self.templates:
            template.fold(data)

    def get_x(self):
        """Return ``get_x()`` of the template at the current parameter values."""
        p = tuple(par.val for par in self.pars)
        template = self.query(p)
        return template.get_x()

    def get_y(self):
        """Return ``get_y()`` of the template at the current parameter values."""
        p = tuple(par.val for par in self.pars)
        template = self.query(p)
        return template.get_y()

    def query(self, p):
        """Return the template stored for the parameter values `p`.

        Raises
        ------
        ModelErr
            If `p` does not match an entry of the template library.
        """
        try:
            return self.index[tuple(p)]
        except KeyError:
            raise ModelErr("Interpolation of template parameters was disabled for "
                           "this model, but parameter values not in the template "
                           "library have been requested. Please use gridsearch "
                           "method and make sure the sequence option is consistent "
                           "with the template library")

    @modelCacher1d
    def calc(self, p, x0, x1=None, *args, **kwargs):
        """Evaluate the template matching `p` on the input grid."""
        table_model = self.query(p)

        # Evaluate the selected template on the input grid (x0, [x1]).
        return table_model(x0, x1, *args, **kwargs)
# Registry of template interpolators, keyed by name. Each entry is a
# (class, keyword arguments) pair; create_template_model calls the
# class with those keywords plus 'template_model' and 'name'.
interpolators = {
    'default': (Template, dict(k=2, order=2)),
}