Source code for traits.interface_checker

# (C) Copyright 2005-2023 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

""" An attempt at type-safe casting.
"""

from inspect import getfullargspec, getmro
import logging
from types import FunctionType

from .has_traits import HasTraits

logger = logging.getLogger(__name__)


# Constants:

# Message templates for interface errors.
BAD_SIGNATURE = (
    "The '%s' class signature for the '%s' method is different "
    "from that of the '%s' interface."
)
MISSING_METHOD = (
    "The '%s' class does not implement the '%s' method of the "
    "'%s' interface."
)
MISSING_TRAIT = (
    "The '%s' class does not implement the %s trait(s) of the "
    "'%s' interface."
)


[docs]class InterfaceError(Exception): """ The exception raised if a class does not really implement an interface. """ pass
[docs]class InterfaceChecker(HasTraits): """ Checks that interfaces are actually implemented. """
[docs] def check_implements(self, cls, interfaces, error_mode): """ Checks that the class implements the specified interfaces. 'interfaces' can be a single interface or a list of interfaces. """ # If a single interface was specified then turn it into a list: try: iter(interfaces) except TypeError: interfaces = [interfaces] # If the class has traits then check that it implements all traits and # methods on the specified interfaces: if issubclass(cls, HasTraits): for interface in interfaces: if not self._check_has_traits_class( cls, interface, error_mode ): return False # Otherwise, just check that the class implements all methods on the # specified interface: else: for interface in interfaces: if not self._check_non_has_traits_class( cls, interface, error_mode ): return False return True
def _check_has_traits_class(self, cls, interface, error_mode): """ Checks that a 'HasTraits' class implements an interface. """ return self._check_traits( cls, interface, error_mode ) and self._check_methods(cls, interface, error_mode) def _check_non_has_traits_class(self, cls, interface, error_mode): """ Checks that a non-'HasTraits' class implements an interface. """ return self._check_methods(cls, interface, error_mode) def _check_methods(self, cls, interface, error_mode): """ Checks that a class implements the methods on an interface. """ cls_methods = self._get_public_methods(cls) interface_methods = self._get_public_methods(interface) for name in interface_methods: if name not in cls_methods: return self._handle_error( MISSING_METHOD % ( self._class_name(cls), name, self._class_name(interface), ), error_mode, ) # Check that the method signatures are the same: cls_argspec = getfullargspec(cls_methods[name]) interface_argspec = getfullargspec(interface_methods[name]) if cls_argspec != interface_argspec: return self._handle_error( BAD_SIGNATURE % ( self._class_name(cls), name, self._class_name(interface), ), error_mode, ) return True def _check_traits(self, cls, interface, error_mode): """ Checks that a class implements the traits on an interface. """ missing = set(interface.class_traits()).difference( set(cls.class_traits()) ) if len(missing) > 0: return self._handle_error( MISSING_TRAIT % ( self._class_name(cls), repr(list(missing))[1:-1], self._class_name(interface), ), error_mode, ) return True def _get_public_methods(self, cls): """ Returns all public methods on a class. Returns a dictionary containing all public methods keyed by name. """ public_methods = {} for c in getmro(cls): # Stop when we get to 'HasTraits'!: if c is HasTraits: break for name, value in c.__dict__.items(): if (not name.startswith("_")) and ( type(value) is FunctionType ): if name not in public_methods: public_methods[name] = value return public_methods def _class_name(self, cls): return cls.__name__ def _handle_error(self, msg, error_mode): if error_mode > 1: raise InterfaceError(msg) if error_mode == 1: logger.warning(msg) return False
# A default interface checker: checker = InterfaceChecker()
[docs]def check_implements(cls, interfaces, error_mode=0): """ Checks that the class implements the specified interfaces. 'interfaces' can be a single interface or a list of interfaces. """ return checker.check_implements(cls, interfaces, error_mode)