########################################################################
# $Header: /var/local/cvsroot/4Suite/Ft/Lib/TestSuite/TestModule.py,v 1.5 2002/08/07 18:07:12 molson Exp $
"""
Provides the TestModule class for wrapping modules/packages.

Copyright 2002 Fourthought, Inc. (USA).
Detailed license and copyright information: http://4suite.org/COPYRIGHT
Project home, documentation, distributions: http://4suite.org/
"""

__revision__ = "$Id: TestModule.py,v 1.5 2002/08/07 18:07:12 molson Exp $"

import os, imp
import TestObject, TestFunction, TestMode, TestCoverage

def GetModuleName(filename):
    """
    Return the importable module name for the path 'filename', or None.

    A directory yields its base name when it contains an '__init__' file
    with one of the suffixes reported by imp.get_suffixes() (i.e. it is a
    package); trailing path separators are ignored, so 'pkg/' gives 'pkg'.
    A file yields its base name minus the extension when that extension is
    one of those suffixes.  Anything else yields None.
    """
    # Only the suffix string of each (suffix, mode, type) tuple matters here
    suffixes = [info[0] for info in imp.get_suffixes()]

    if os.path.isdir(filename):
        # Strip trailing separators so 'pkg/' names 'pkg' rather than ''
        stripped = filename.rstrip(os.sep + (os.altsep or ''))
        dirname = os.path.basename(stripped or filename)
        # A directory is only a package if it holds an __init__ module
        for suffix in suffixes:
            if os.path.isfile(os.path.join(filename, '__init__' + suffix)):
                return dirname
        return None

    modname, extension = os.path.splitext(os.path.basename(filename))
    if extension in suffixes:
        return modname
    return None
    

class TestModule(TestObject.TestObject):
    """Test object for a module or package.

    The sub-tests of a TestModule are a callable named 'Test' defined in the
    wrapped module (if any) and, for a package, one child TestModule per
    sub-module/sub-package found in the package directory.  run() executes
    every sub-test once for each selected test mode.
    """

    def __init__(self, name, module, addModes, skipModes, allModes):
        """
        name      -- display name of this test group
        module    -- the imported module/package object; loadTest() also
                     copes with a false value (names are then absolute)
        addModes  -- collection of mode names explicitly requested
        skipModes -- collection of mode names that must not be run
        allModes  -- if true, run every available mode not in skipModes
        """
        TestObject.TestObject.__init__(self, name)
        self.module = module
        self.path = getattr(module, '__name__', None)
        self.addModes = addModes
        self.skipModes = skipModes
        self.allModes = allModes

        # Resolve the modes this module will be run under (reads self.module)
        self.modes = self.getModes(addModes, skipModes, allModes)
        # Filled lazily by getTests(), or explicitly through addTest()
        self.tests = []
        return

    def loadTest(self, name):
        """Import sub-module 'name' (relative to this module) and wrap it.

        Returns a new TestModule that shares this object's mode selection.
        Import errors propagate to the caller.
        """
        if self.module:
            module_name = '%s.%s' % (self.module.__name__, name)
        else:
            module_name = name
        # A non-empty fromlist makes __import__ return the leaf module
        # instead of the top-level package.
        module = __import__(module_name, {}, {}, ['*'])
        return TestModule(name, module, self.addModes, self.skipModes,
                          self.allModes)

    def addTest(self, name):
        """Load sub-module 'name' via loadTest(), append it to the test
        list and return it."""
        test = self.loadTest(name)
        self.tests.append(test)
        return test

    def getModes(self, addModes, skipModes, allModes):
        """Return the list of modes to run for this module.

        Candidates come from the module's MODES attribute, falling back to a
        single TestMode.DefaultMode().  Modes named in skipModes are never
        selected.  With allModes every remaining candidate is used;
        otherwise the candidates named in addModes are used or, when none of
        those match, the candidates whose 'default' flag is true.
        """
        modes = getattr(self.module, 'MODES', [TestMode.DefaultMode()])
        candidates = [mode for mode in modes if mode.name not in skipModes]
        if allModes:
            return candidates

        run_modes = [mode for mode in candidates if mode.name in addModes]
        if not run_modes:
            # No requested mode is available: fall back to the defaults
            run_modes = [mode for mode in candidates if mode.default]
        return run_modes

    def getTests(self):
        """
        Get the test objects contained within this module.

        Results are cached in self.tests: discovery only runs while that
        list is empty, so tests added beforehand via addTest() suppress it.
        """
        if not self.tests:
            # A callable named 'Test' in the module is its test function
            test_func = getattr(self.module, 'Test', None)
            if callable(test_func):
                self.tests.append(TestFunction.TestFunction(test_func))

            # A package additionally contributes its sub-modules
            if hasattr(self.module, '__path__'):
                self._loadPackageTests()

        return self.tests

    def _loadPackageTests(self):
        """Add a TestModule for every sub-module/sub-package of this package,
        then bracket the test list with coverage hooks if requested."""
        dirs, files = self._findSubModules()

        # Let the package reorder or filter the test lists
        if hasattr(self.module, 'PreprocessFiles'):
            dirs, files = self.module.PreprocessFiles(dirs, files)

        for name in dirs + files:
            self.addTest(name)

        # If this package defines a CoverageModule, run the coverage start
        # function first and the end function last.
        if hasattr(self.module, 'CoverageModule'):
            ignored = getattr(self.module, 'CoverageIgnored', None)
            ct = TestCoverage.TestCoverage(self.module.CoverageModule,
                                           ignored)
            self.tests.insert(0, TestFunction.TestFunction(ct._start))
            self.tests.append(TestFunction.TestFunction(ct._end))
        return

    def _findSubModules(self):
        """Return (dirs, files): alphabetically sorted names of the
        sub-packages and sub-modules found in this package's directory."""
        package_dir = self.module.__path__[0]

        # Names already taken; seeded with the package's own file so the
        # package itself is not listed as one of its sub-modules.
        seen = {GetModuleName(self.module.__file__): 1}
        dirs = []
        files = []
        for entry in os.listdir(package_dir):
            path = os.path.join(package_dir, entry)
            name = GetModuleName(path)
            if not name or name in seen:
                continue
            # Several entries (e.g. a source and a compiled file) may map
            # to the same module name; keep only the first.
            seen[name] = 1
            if os.path.isdir(path):
                dirs.append(name)
            else:
                files.append(name)

        # Default running order is alphabetical
        dirs.sort()
        files.sort()
        return dirs, files

    def showTests(self, indent):
        """Print this module's name prefixed by 'indent'.  A package is
        marked with a trailing os.sep and its sub-tests are listed beneath
        it, indented by two more spaces."""
        if hasattr(self.module, '__path__'):
            print('%s%s%s' % (indent, self.name, os.sep))
            for test in self.getTests():
                test.showTests(indent + ' ' * 2)
        else:
            print('%s%s' % (indent, self.name))
        return

    def run(self, tester):
        """Run every sub-test once per usable mode, inside a tester group
        named after this module."""
        tester.startGroup(self.name)

        # Keep only the modes whose initialize() reports success
        modes = [mode for mode in self.modes if mode.initialize(tester)]
        if not modes:
            tester.warning("All modes have been skipped")

        for mode in modes:
            mode.start(tester)
            try:
                tests = self.getTests()
                if not tests:
                    tester.warning('Module does not define any tests')
                for test in tests:
                    self.runTest(tester, test)
            finally:
                mode.finish(tester)

        tester.groupDone()
        return

    def runTest(self, tester, testObject):
        """Run one sub-test and repair any tester state it leaves open.

        KeyboardInterrupt and SystemExit propagate.  Any other exception is
        reported with tester.exception() and the test/groups it left open
        are closed.  After a normal return, an unfinished test or unclosed
        groups are closed with a warning; closing too many groups is
        reported with tester.error().
        """
        # Group nesting depth before the test, to detect misbehaving tests
        depth = len(tester.groups)

        try:
            testObject.run(tester)
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            # Deliberately broad: a failing test must never abort the whole
            # run, and old-style/string exceptions are not Exception
            # subclasses.
            tester.exception('Unhandled exception in test')
            # Clean up for the interrupted test
            if tester.test:
                tester.testDone()
            while len(tester.groups) > depth:
                tester.groupDone()
            return

        if tester.test:
            tester.warning('Failed to finish test (fixed)')
            tester.testDone()

        # Verify proper group count
        count = len(tester.groups) - depth
        if count < 0:
            tester.error('Closed too many groups')
        elif count > 0:
            tester.warning('Failed to close %d groups (fixed)' % count)
            for _ in range(count):
                # 1-tuple so a tuple-valued group entry cannot break '%'
                tester.message('Closing group %s' % (tester.groups[-1],))
                tester.groupDone()
        return
