Source code for sherpa.utils.testing

#
#  Copyright (C) 2017  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

import numpy
import unittest
import os
import importlib

from sherpa.utils._utils import sao_fcmp

try:
    import pytest
    HAS_PYTEST = True
except ImportError:
    HAS_PYTEST = False


def _get_datadir():
    """Return the location of the optional Sherpa test data.

    The ``sherpatest`` package is tried first. If it can not be
    imported, the ``sherpa-test-data/sherpatest`` directory found next
    to the ``sherpa`` package is used, provided it is a non-empty
    directory.

    Returns
    -------
    datadir : str or None
        The directory containing the test data, or None if it can not
        be found.
    """
    try:
        import sherpatest
    except ImportError:
        pass
    else:
        return os.path.dirname(sherpatest.__file__)

    try:
        import sherpa
    except ImportError:
        # neither sherpatest nor sherpa can be found, falling back to None
        return None

    datadir = os.path.join(os.path.dirname(sherpa.__file__), os.pardir,
                           'sherpa-test-data', 'sherpatest')

    # An empty directory most likely means the submodule was not
    # initialized. isdir (rather than exists) guarantees that os.listdir
    # is only ever called on a directory.
    if not os.path.isdir(datadir) or not os.listdir(datadir):
        return None

    return datadir


class SherpaTestCase(unittest.TestCase):
    """
    Base class for Sherpa unit tests.

    The use of this class is deprecated in favor of pytest functions.
    """

    # The location of the Sherpa test data (it is optional)
    datadir = _get_datadir()

    def make_path(self, *segments):
        """Add the segments onto the test data location.

        Parameters
        ----------
        *segments
            Path segments to combine together with the location of the
            test data.

        Returns
        -------
        fullpath : None or string
            The full path to the requested location within the test
            data, or None if the data directory is not set.
        """
        if self.datadir is None:
            return None
        return os.path.join(self.datadir, *segments)

    # NOTE: it is not clear what the benefit of this is over
    # numpy.testing.assert_allclose(), which was added in version 1.5
    # of NumPy.
    def assertEqualWithinTol(self, first, second, tol=1e-7, msg=None):
        """Check that the values are equal within an absolute tolerance.

        Parameters
        ----------
        first : number or array_like
            The expected value, or values.
        second : number or array_like
            The value, or values, to check. If first is an array, then
            second must be an array of the same size. If first is
            a scalar then second can be a scalar or an array.
        tol : number
            The absolute tolerance used for comparison.
        msg : string
            The message to display if the check fails.
        """
        # sao_fcmp returns 0 for elements that compare equal, so any
        # non-zero element is a failure.
        self.assertFalse(numpy.any(sao_fcmp(first, second, tol)), msg)

    def assertNotEqualWithinTol(self, first, second, tol=1e-7, msg=None):
        """Check that the values are not equal within an absolute tolerance.

        For array input every element must differ for the check to pass.

        Parameters
        ----------
        first : number or array_like
            The expected value, or values.
        second : number or array_like
            The value, or values, to check. If first is an array, then
            second must be an array of the same size. If first is
            a scalar then second can be a scalar or an array.
        tol : number
            The absolute tolerance used for comparison.
        msg : string
            The message to display if the check fails.
        """
        self.assertTrue(numpy.all(sao_fcmp(first, second, tol)), msg)

    # for running regression tests from sherpa-test-data
    def run_thread(self, name, scriptname='fit.py'):
        """Run a regression test from the sherpa-test-data submodule.

        The script is executed with the test data directory as the
        working directory (restored afterwards), and the names it
        defines are stored in ``self.locals``. If the test data is not
        available the test is skipped.

        Parameters
        ----------
        name : string
            The name of the science thread to run (e.g., pha_read,
            radpro). The name should match the corresponding thread
            name in the sherpa-test-data submodule. See examples below.
        scriptname : string
            The suffix of the test script file name, usually "fit.py".

        Examples
        --------
        Regression test script file names have the structure
        "name-scriptname". By default, scriptname is set to "fit.py".
        For example, if one wants to run the regression test
        "pha_read-fit.py", they would write

        >>> self.run_thread("pha_read")

        If the regression test name is "lev3fft-bar.py", they would do

        >>> self.run_thread("lev3fft", scriptname="bar.py")
        """
        # Without this check os.chdir(None) fails with an unhelpful
        # TypeError; skip instead, matching the requires_data decorator.
        if self.datadir is None:
            self.skipTest("required test data missing")

        scriptname = name + "-" + scriptname
        self.locals = {}
        cwd = os.getcwd()
        os.chdir(self.datadir)
        try:
            with open(scriptname, "rb") as fh:
                cts = fh.read()
            exec(compile(cts, scriptname, 'exec'), {}, self.locals)
        finally:
            os.chdir(cwd)
def has_package_from_list(*packages):
    """
    Returns True if at least one of the ``packages`` args is importable.
    """
    for package in packages:
        try:
            importlib.import_module(package)
        except Exception:
            # Treat any error raised while importing (not only
            # ImportError) as "not available", but let KeyboardInterrupt
            # and SystemExit propagate.
            continue
        return True
    return False


if HAS_PYTEST:
    # Pytest cannot be assumed to be installed by the regular user,
    # unlike unittest, which is part of Python's standard library. The
    # decorators are still defined if pytest is missing, but if the
    # decorated tests are run they throw an exception prompting users to
    # install pytest, in those cases where pytest is not installed
    # automatically.

    def requires_data(test_function):
        """
        Decorator for test functions requiring external data (i.e. data
        not distributed with Sherpa itself). The test is skipped when
        the data is missing.

        See PR #391 for why this is a function:
        https://github.com/sherpa/sherpa/pull/391
        """
        condition = SherpaTestCase.datadir is None
        msg = "required test data missing"
        return pytest.mark.skipif(condition, reason=msg)(test_function)

    def requires_package(msg=None, *packages):
        """
        Decorator for test functions requiring specific packages.

        The test is skipped unless at least one of ``packages`` can be
        imported. The check is made when the decorator is created.
        ``msg`` is the skip reason; if not given, a message listing the
        packages is used.
        """
        condition = has_package_from_list(*packages)
        msg = msg or "required module missing among {}.".format(
            ", ".join(packages))

        def decorator(test_function):
            return pytest.mark.skipif(not condition, reason=msg)(test_function)

        return decorator

    def requires_plotting(test_function):
        """
        Decorator for test functions requiring a plotting library
        (pylab or pychips).
        """
        packages = ('pylab', 'pychips')
        msg = "plotting backend required"
        return requires_package(msg, *packages)(test_function)

    def requires_pylab(test_function):
        """
        Decorator for test functions requiring matplotlib; the test is
        skipped unless the pylab module can be imported.
        """
        packages = ('pylab', )
        msg = "matplotlib backend required"
        return requires_package(msg, *packages)(test_function)

    def requires_fits(test_function):
        """
        Decorator for test functions requiring a FITS I/O backend; the
        test is skipped unless astropy.io.fits or pycrates can be
        imported.
        """
        packages = ('astropy.io.fits', 'pycrates', )
        msg = "FITS backend required"
        return requires_package(msg, *packages)(test_function)

    def requires_group(test_function):
        """Decorator for test functions requiring group library"""
        return requires_package("group library required", 'group')(test_function)

    def requires_stk(test_function):
        """Decorator for test functions requiring stk library"""
        return requires_package("stk library required", 'stk')(test_function)

    def requires_ds9(test_function):
        """Decorator for test functions requiring ds9"""
        return requires_package('ds9 required', 'sherpa.image.ds9_backend')(test_function)

    def requires_xspec(test_function):
        """Decorator for test functions requiring the sherpa.astro.xspec module"""
        return requires_package("xspec required", "sherpa.astro.xspec")(test_function)

else:
    def wrapped():
        # Replacement for any decorated test: running it reports that
        # pytest is missing.
        raise ImportError(PYTEST_MISSING_MESSAGE)

    def make_fake():
        # Build a placeholder decorator that ignores its arguments and
        # returns ``wrapped``.
        def wrapper(*args, **kwargs):
            return wrapped
        return wrapper

    requires_data = make_fake()
    requires_plotting = make_fake()
    requires_pylab = make_fake()
    requires_fits = make_fake()
    requires_group = make_fake()
    requires_stk = make_fake()
    requires_ds9 = make_fake()
    requires_xspec = make_fake()

    def requires_package(*args):
        # requires_package is a decorator factory, so return a
        # placeholder decorator rather than being one.
        return make_fake()
# Message of the ImportError raised when a test decorated with one of the
# placeholder decorators (used when pytest is not installed) is run.
PYTEST_MISSING_MESSAGE = ("Package `pytest` is missing. Please install `pytest` "
                          "before running tests or using the test decorators")