#!/usr/bin/python

# reconf-inetd - reconfigure and restart inetd
# Copyright (C) 2011, 2012 Serafeim Zanikolas <sez@debian.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import os
import re
import sys
import shutil
import logging
import logging.handlers
import commands
from glob import glob

__author__ = "Serafeim Zanikolas <sez@debian.org>"
__copyright__ = "Copyright (c) 2011-2012 Serafeim Zanikolas"
__license__ = "GPL-2+"
__version__ = "1.120603"

# User written into inetd.conf for type=internal fragments that name no user.
INETD_USER = 'root'
# Default directory holding the reconf-inetd fragment files.
RECONF_INETD_FRAGMENTS_DIR = '/usr/share/reconf-inetd'
# Default directory holding the shadow copies of applied fragments.
SHADOW_FRAGMENTS_DIR = '/var/lib/reconf-inetd'
# Services database used to look up a protocol; the SERVICES_FILENAME
# environment variable overrides the default when set to a non-empty value.
SERVICES_FILENAME = os.environ.get('SERVICES_FILENAME')
if not SERVICES_FILENAME:
    SERVICES_FILENAME = '/etc/services'

class InvalidEntryException(Exception):
    """Raised when an XFragment file contains an invalid entry."""

class MissingFieldException(Exception):
    """Raised when an xfragment entry lacks a mandatory field."""

class UnsupportedOperatorException(Exception):
    """Raised when an XFragment entry uses an operator that is not supported.

    At present '=' is the only supported operator.
    """

class UnimplementedException(Exception):
    """Raised by a base-class method that a subclass failed to override."""


class BaseService(object):
    """Format-independent service entry.

    Two services are considered equal (and hash identically) when their
    service name, protocol and server program are equal.
    """

    def __init__(self, attrs, logger):
        """Called by subclasses InetdService and XFragment.

        attrs is a dict of service attributes (it is stored, not copied);
        logger is the logger used for diagnostics.

        self.service_id and self.hash are not defined until they are first
        needed (hashing or comparison), because the attributes that are used
        to construct the service_id are not necessarily available during the
        instantiation of a BaseService object."""
        self.attrs = attrs
        self.logger = logger
        self.service_id = None
        self.hash = None
        self.existing_server_path = None

    def get_name(self):
        """Return the service name, or None if it is not set."""
        return self.attrs.get('service')

    def get_protocol(self):
        """Return the protocol, or None if it is not set."""
        return self.attrs.get('protocol')

    def get_server(self):
        """Return the server program; must be overridden by subclasses."""
        raise UnimplementedException

    def __hash__(self):
        """Base the object's hash on memoized service_id."""
        if self.service_id is None:
            self._init_service_id()
        if self.hash is None:
            self.hash = hash(self.service_id)
        return self.hash

    def _init_service_id(self):
        """Construct a unique string for this service.

        DEP9: An inetd.conf service entry will be considered to be "matching"
        a reconf-inetd fragment when the following fields are equal: service
        name, protocol, and server program.

        See the constructor's docstring on why this code is not there
        instead."""
        name = self.get_name()
        protocol = self.get_protocol()
        server = self.get_server()
        self.service_id = "%s|%s|%s" % (name, protocol, server)

    def _get_service_id(self):
        """Return the service_id, computing it first if necessary."""
        if self.service_id is None:
            self._init_service_id()
        return self.service_id

    def __eq__(self, other):
        """Compare by service_id.

        The ids are computed on demand: comparing the raw attributes would
        make any two not-yet-hashed services equal (None == None)."""
        if not isinstance(other, BaseService):
            return NotImplemented
        return self._get_service_id() == other._get_service_id()

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def get_server_args(self):
        """Return the server's command line arguments.

        Raises KeyError if the 'server_args' attribute is not set."""
        return self.attrs['server_args']

    def has_existing_server_path(self):
        """Return True when server_path is defined and refers to an existing
        file. The result is memoized."""
        if self.existing_server_path is None:
            srv_path = self.get_server()
            # bool() so that a missing server is memoized as False rather
            # than leaving the cache unset (None)
            self.existing_server_path = bool(srv_path and
                                             os.path.exists(srv_path))
        return self.existing_server_path


class BaseServiceContainer(object):
    """Collection of BaseService objects, keyed by the services themselves.

    InetdServiceContainer and XFragmentContainer derive from this class."""

    def __init__(self, logger):
        # maps each service to itself, so that a matching (equal) service
        # can be used to retrieve the stored one
        self.all_services = {}
        self.logger = logger

    def get_all_services(self):
        """Return all stored services."""
        return self.all_services.values()

    def get_matching_entry(self, srv):
        """Return the stored service matching srv, or None if there is no
        such service."""
        return self.all_services.get(srv)

    def has_matching_entry(self, srv):
        """Tell whether a service matching srv is stored."""
        return srv in self.all_services

    def has_matching_enabled_entry(self, srv):
        """Tell whether a matching service is stored and enabled."""
        found = self.get_matching_entry(srv)
        if found is None:
            return False
        return found.is_enabled()

    def has_matching_entry_with_identical_srv_args(self, srv):
        """Tell whether a matching service is stored whose command line
        server arguments are the same as those of srv."""
        found = self.get_matching_entry(srv)
        if found is None:
            return False
        return found.get_server_args() == srv.get_server_args()

    def add_service(self, service):
        """Store service, replacing any matching one."""
        self.logger.debug('adding service %s' % service.get_name())
        self.all_services[service] = service

    def remove_service(self, service):
        """Drop the stored service matching service (KeyError if none)."""
        self.logger.debug('removing service %s' % service.get_name())
        del self.all_services[service]


class InetdService(BaseService):
    """A representation of an inetd.conf line."""

    # attributes that must all be present and non-empty for a valid entry
    mandatory_attrs = ['service', 'socket_type', 'wait', 'protocol', 'server']

    # MAINT_DISABLED: lines starting with '#<off># ' may be re-enabled
    # USER_DISABLED: lines starting with '# ' are local policy and must not be
    #                touched
    ENABLED, MAINT_DISABLED, USER_DISABLED = 0, 1, 2
    # line prefix for each status, indexed by the status constant
    STATUS_STR = ['', '#<off># ', '# ']

    def __init__(self, attrs, lineno, logger):
        """lineno is the line number at which this service is stored in
        inetd.conf

        A missing 'user' attribute defaults to 'nobody'; the status is
        initially ENABLED. self.is_valid tells whether all the mandatory
        attributes are set."""
        super(InetdService, self).__init__(attrs, logger)
        if self.attrs.get('user') is None:
            self.attrs['user'] = 'nobody'
        self.lineno = lineno
        self.is_valid = all(attrs.get(key)
                            for key in InetdService.mandatory_attrs)
        self.set_status(self.ENABLED)

        # .get(): an entry without a service name is merely invalid, and
        # must not blow up with a KeyError while being reported
        self.logger.debug('loaded %s; is_valid: %s' % (attrs.get('service'),
                                                       self.is_valid))

    def set_status(self, status):
        """Set the status to one of ENABLED, MAINT_DISABLED, USER_DISABLED."""
        self.attrs['status'] = status

    def status_str(self):
        """Return the inetd.conf line prefix for the current status."""
        return InetdService.STATUS_STR[self.attrs['status']]

    def enable(self):
        self.set_status(InetdService.ENABLED)

    def is_enabled(self):
        return self.attrs['status'] == InetdService.ENABLED

    def is_maint_disabled(self):
        return self.attrs['status'] == InetdService.MAINT_DISABLED

    def is_user_disabled(self):
        return self.attrs['status'] == InetdService.USER_DISABLED

    def __repr__(self):
        """one-line representation of the service as it would appear in
        inetd.conf"""
        srv_args = self.attrs.get('server_args')
        return '%s%s%s' % (self.status_str(),
                ('%(service)s %(socket_type)s %(protocol)s '
                 '%(wait)s %(user)s %(server)s') % self.attrs,
                ' %s' % srv_args if srv_args else '')

    def get_server(self):
        """Return the server program of this entry, or None if unset.

        When the server is a tcpd or rpcd wrapper, the first word of the
        server arguments (if any) is returned instead."""
        server = self.attrs.get('server')
        # a tuple keeps the "is set" guard in front of both suffix tests
        if server and server.endswith(('/tcpd', '/rpcd')):
            server_args = self.attrs.get('server_args')
            if server_args:
                server = server_args.split()[0]
        return server

class InetdConfChangeSet(object):
    """Encapsulates the logic for finding out which services must be added,
    removed or enabled in inetd.conf, as well as which shadow fragment files
    must be created or removed, given: an inetd.conf file, a set of
    reconf-inetd fragment files, and a set of shadow fragment files."""

    def __init__(self, inetd_serv_container, fragment_container,
                 shadow_container, logger):
        self.inetd_serv_container = inetd_serv_container
        self.fragment_container = fragment_container
        self.shadow_container = shadow_container
        self.logger = logger
        # memoized results of the services_to_*() methods
        self.to_enable, self.to_add, self.to_remove = None, None, None

    def services_to_enable(self):
        """Memoized list of InetdService objects that represent previously
        maintainer_disabled inetd.conf entries that must be enabled.

        Such an entry is enabled when its server program exists and it has
        both a matching fragment and a matching shadow fragment."""
        if self.to_enable is None:
            fragments = self.fragment_container
            shadows = self.shadow_container
            # has_matching_entry() is a dictionary lookup; testing against
            # get_all_services() would rebuild and scan a list per service
            self.to_enable = \
                [s for s in self.inetd_serv_container.get_all_services()
                 if s.is_maint_disabled() and
                    s.has_existing_server_path() and
                    fragments.has_matching_entry(s) and
                    shadows.has_matching_entry(s)]
        return self.to_enable

    def services_to_add(self):
        """Memoized list of XFragment objects that represent services to be
        added to inetd.conf."""
        if self.to_add is None:
            self.to_add = \
                [f for f in self.fragment_container.get_valid_services()
                 if not self.inetd_serv_container.has_matching_entry(f)]
        return self.to_add

    def services_to_remove(self):
        """Memoized list of InetdService objects that represent entries to be
        removed from inetd.conf.

        An entry is removed when it is not user-disabled, its server program
        does not exist, it has no matching fragment, and it has a matching
        shadow fragment with identical server arguments.

        This will fail if called _after_ services have been added, because
        they've been added as XFragment objects which do not implement
        is_user_disabled()."""
        if self.to_remove is None:
            fragments = self.fragment_container
            shadows = self.shadow_container
            self.to_remove = \
                [s for s in self.inetd_serv_container.get_all_services()
                 if (not s.is_user_disabled()) and
                    (not s.has_existing_server_path()) and
                    (not fragments.has_matching_entry(s)) and
                    shadows.has_matching_entry_with_identical_srv_args(s)]
        return self.to_remove

    def is_not_empty(self):
        """return True if there is at least one service to enable, add or
        remove"""
        return any(self.services_to_enable()) or \
               any(self.services_to_add()) or any(self.services_to_remove())

    def update_shadow_fragments(self, shadow_dir):
        """Create and remove shadow fragments in shadow_dir, so as to
        reflect the services added to and removed from inetd.conf."""
        self._create_shadow_fragments(shadow_dir)
        self._remove_shadow_fragments()

    def _create_shadow_fragments(self, shadow_dir):
        """Create a shadow fragment for every service that's added to
        inetd.conf."""
        for srv in self.services_to_add():
            basename = os.path.basename(srv.source_filename)
            shadow_filename = os.path.join(shadow_dir, basename)
            self.logger.debug('creating shadow file %s' % shadow_filename)
            shutil.copy(srv.source_filename, shadow_filename)

    def _remove_shadow_fragments(self):
        """Remove a shadow fragment for every service that is removed from
        inetd.conf."""
        for srv in self.services_to_remove():
            fragment = self.shadow_container.get_matching_entry(srv)
            if fragment is None:
                self.logger.warning(('unable to determine the source '
                                     'filename for "%s"') % str(srv))
                continue
            shadow_filename = fragment.source_filename
            if os.path.exists(shadow_filename):
                self.logger.debug('removing shadow file %s' % shadow_filename)
                os.remove(shadow_filename)
            else:
                self.logger.debug(('not removing shadow file %s '
                                   '(does not exist)') % shadow_filename)

class InetdServiceContainer(BaseServiceContainer):
    """Representation of an inetd.conf file, and related logic for managing
    service entries therein."""

    def __init__(self, logger):
        super(InetdServiceContainer, self).__init__(logger)
        # the lines of inetd.conf; a service's lineno is its 0-based index
        # into this list
        self.inetd_conf_lines = []
        self.removed_services = []

    def load_service_entries(self, inetd_conf_fd):
        """Load the lines and service entries of the inetd.conf file that is
        open as inetd_conf_fd."""
        InetdConfParser(self.logger, self, inetd_conf_fd)

    def get_enabled_services(self):
        """Return a list of services that are neither maintainer- nor
        user-disabled."""
        # services are objects, not dictionaries: ask them for their status
        return [s for s in self.all_services.values() if s.is_enabled()]

    def prepare_changes(self, change_set):
        """Apply a given set of changes to self.inetd_conf_lines, a list
        representation of inetd.conf.

        The changes are not effective until the persist method is called.
        Note that service disable is not foreseen according to DEP9."""
        self.logger.debug('preparing changes')
        # removals go first: services_to_remove() must be computed before
        # any XFragment object is added to this container
        self._remove_services(change_set.services_to_remove())
        self._enable_services(change_set.services_to_enable())
        self._add_services(change_set.services_to_add())

    def _add_services(self, to_add):
        """Add the given XFragment objects as inetd.conf entries."""
        self.logger.info('adding %d new service(s)' % (len(to_add)))
        for srv in to_add:
            self.inetd_conf_lines.append(srv.to_inetd())
            if srv.lineno is None:
                # index of the line that was just appended (0-based, like
                # the line numbers of the parsed entries)
                srv.lineno = len(self.inetd_conf_lines) - 1
            self.logger.info('adding service %s' % srv.service_id)
            self.add_service(srv)

    def _remove_services(self, to_remove):
        """Remove the inetd.conf entries that match the given InetdService
        objects."""
        self.logger.debug('removing %d service(s)' % (len(to_remove)))
        for srv in to_remove:
            self.logger.info('removing service %s (line %d)' %
                    (srv.get_name(), srv.lineno))
            # blank the line rather than delete it, so that the line
            # numbers of the other services remain valid
            self.inetd_conf_lines[srv.lineno] = ''
            self.logger.info('removing service %s' % srv.service_id)
            self.remove_service(srv)

    def _enable_services(self, to_enable):
        """Enable the inetd.conf entries that match the given InetdService
        objects."""
        self.logger.debug('enabling %d service(s)' % (len(to_enable)))
        for srv in to_enable:
            self.logger.info('enabling service %s (line %d)' %
                    (srv.get_name(), srv.lineno))
            srv.enable()
            self.logger.info('enabling service %s' % srv.service_id)
            self.inetd_conf_lines[srv.lineno] = str(srv)

    def persist(self, filename):
        """Actually store on-disk whatever changes were made."""
        self.logger.debug('storing %s' % filename)
        lines = [self._terminate_line(line) for line in self.inetd_conf_lines]
        with open(filename, 'w') as fd:
            fd.writelines(lines)

    def _terminate_line(self, line):
        """Terminate the given line, unless it's already terminated.

        Empty lines (i.e. removed entries) are left empty."""
        if line == '' or line[-1] == '\n':
            return line
        return '%s\n' % line


class InetdConfParser(object):
    """Parser of inetd.conf files.

    Instantiating the class does all the work: the lines read from the file
    are appended to the container's inetd_conf_lines, and every valid
    service entry is added to the container as an InetdService."""

    space_pat = re.compile(r"\s+")

    def __init__(self, logger, container, inetd_conf_fd):
        """Read inetd_conf_fd and load its contents into container."""
        self.logger = logger
        self.container = container
        new_lines = inetd_conf_fd.readlines()
        self.container.inetd_conf_lines.extend(new_lines)
        for i, line in enumerate(new_lines):
            self.logger.debug('line %d: %s' % (i, line))
            self.parse_line(line, i)

    def parse_line(self, line, lineno):
        """Parse a single line, and add the service entry that it holds (if
        any, and if valid) to the container.

        lineno is the 0-based position of the line in the file."""
        attrs = {}
        fields = InetdConfParser.space_pat.split(line)
        # NB a newline-terminated line yields a trailing empty field
        nfields = len(fields)
        if nfields < 6:
            if line and not line.startswith('#') and not line.isspace():
                self.logger.warning('skipping invalid line %d: %s"\n' %
                        (lineno, line))
            return

        if line.startswith('#<off># '):
            status = InetdService.MAINT_DISABLED
            fields = fields[1:] # drop the '#<off>#' marker
        elif line.startswith('# '):
            status = InetdService.USER_DISABLED
            fields = fields[1:] # drop the '#' marker
        elif line.startswith('#'):
            # skip comment line
            self.logger.debug('skipping comment line %d: %s' % (lineno, line))
            return
        else:
            status = InetdService.ENABLED

        try:
            attrs['service'] = fields[0]
            attrs['socket_type'] = fields[1]
            attrs['protocol'] = fields[2] #[,sndbuf=size][,rcvbuf=size]
            attrs['wait'] = fields[3] #wait/nowait[.max]
            attrs['user'] = fields[4] #user[.group] or user[:group]
            attrs['server'] = fields[5]
            if nfields > 6:
                attrs['server_args'] = ' '.join(fields[6:]).strip()
            else:
                attrs['server_args'] = ''
        except IndexError:
            # log invalid entries, but not about the standard inetd.conf header
            if line != "# Internet superserver configuration database\n":
                self.logger.warning('skipping line with invalid number of fields: "%s"' % line)
            return

        srv = InetdService(attrs, lineno, self.logger)
        srv.set_status(status)
        if srv.is_valid:
            self.container.add_service(srv)
        else:
            self.logger.warning('skipping invalid entry in line %d' % lineno)


class XFragment(BaseService):
    """A container for a single xinetd configuration fragment, and associated
    methods."""

    # keys that every fragment must define, with their known values
    mandatory_keys = {
            'socket_type' : set(['stream', 'dgram', 'raw', 'seqpacket']),
            'wait'        : set(['yes', 'no']) }

    # per key, the translation of xinetd values into inetd.conf ones
    map_per_field = { 'wait' : { 'yes' : 'wait',
                                 'no'  : 'nowait'}}

    def __init__(self, service_name, logger, source_filename):
        """source_filename is the file that the fragment is read from."""
        attrs = { 'service' : service_name, 'server_args' : '' }
        super(XFragment, self).__init__(attrs, logger)
        self.errors = []
        self.valid = None
        self.enabled = None
        self.lineno = None
        self.source_filename = source_filename

    def add_attr(self, key, val):
        """Set attribute key to val (neither of which is checked)."""
        self.attrs[key] = val

    def get_server(self):
        """Return the server program of this fragment, or None if unset.

        When the server is a tcpd or rpcd wrapper, or the NAMEINARGS flag is
        set, the first word of the server arguments (if any) is returned
        instead."""
        server = self.attrs.get('server')
        flags = self.attrs.get('flags')
        # the "is set" guard must cover both suffix tests, hence the tuple
        is_wrapper = bool(server) and server.endswith(('/tcpd', '/rpcd'))
        if is_wrapper or (flags and 'NAMEINARGS' in flags.upper()):
            server_args = self.attrs.get('server_args')
            if server_args:
                server = server_args.split()[0]
        return server

    def is_valid(self):
        """Return memoized validation result."""
        if self.valid is None:
            self.validate()
        return self.valid

    def validate(self):
        """Return a two-element list [errors, warnings] for missing mandatory
        fields, invalid values, etc. From xinetd.conf(5):
              socket_type       (mandatory)
              wait              (mandatory)
              user              (non-internal services only)
              server            (non-internal services only)
              protocol          (RPC and unlisted services only)
              rpc_version       (RPC services only)
              rpc_number        (unlisted RPC services only)
              port              (unlisted non-RPC services only)

        As a side effect, missing 'protocol', 'user' and 'server' attributes
        are filled in where a default can be determined.
        """
        # start afresh, so that repeated calls do not duplicate the errors
        self.errors = []
        warnings = []
        for key, valid_values in XFragment.mandatory_keys.items():
            value = self.attrs.get(key)
            if value is None:
                msg = 'missing mandatory key "%s"' % key
                if self.attrs['service']:
                    msg = '%s in fragment for "%s"' % (msg,
                            self.attrs['service'])
                self.errors.append(msg)
            elif value not in valid_values:
                warnings.append('unknown value for key %s: %s\n' % (key, value))

        # protocol key is mandatory only for non-listed services
        if 'protocol' not in self.attrs:
            protocol = self.guess_protocol()
            if protocol is None:
                msg = 'missing protocol key for unlisted service %s'
                self.errors.append(msg % self.attrs['service'])
            else:
                self.attrs['protocol'] = protocol

        # user and server keys are mandatory, unless type=internal
        if 'user' not in self.attrs or 'server' not in self.attrs:
            if 'type' not in self.attrs or \
                   self.attrs['type'].lower() != 'internal':
                self.errors.append(('"user" and "server" keys are mandatory, '
                                    'unless type=internal (which is not the '
                                    'case)'))
            else:
                if 'user' not in self.attrs:
                    self.attrs['user'] = INETD_USER
                if 'server' not in self.attrs:
                    self.attrs['server'] = 'internal'

        self.valid = self.errors == []
        return [self.errors, warnings]

    def _translate(self, key):
        """Translates a value from an xinetd fragment to the corresponding one
        for /etc/inetd.conf, if applicable (or returns the same value).
        """
        value = self.attrs[key]
        try:
            return XFragment.map_per_field[key][value]
        except KeyError:
            return value

    def to_inetd(self):
        """Return fragment as an inetd service entry."""
        translated_attrs = dict((k, self._translate(k)) for k in self.attrs)

        r = ('%(service)s %(socket_type)s %(protocol)s %(wait)s %(user)s '
             '%(server)s') % translated_attrs
        if translated_attrs['server_args']:
            return '%s %s\n' % (r, self.attrs['server_args'])
        return '%s\n' % r

    def create_shadow_fragment(self, shadow_dir):
        """Copy the fragment's source file into shadow_dir."""
        # only the basename: the source filename may be an absolute path
        basename = os.path.basename(self.source_filename)
        shadow_filename = os.path.join(shadow_dir, basename)
        self.logger.debug('creating shadow file %s' % shadow_filename)
        shutil.copy(self.source_filename, shadow_filename)

    def store(self, filename):
        """Store as a fragment file."""
        with open(filename, 'w') as fd:
            content = ['service %s' % self.get_name(),
                       '{',
                       '\n'.join(['%s = %s' % (k, v)
                                  for k, v in self.attrs.items()]),
                       '}']
            fd.write('\n'.join(content))

    def guess_protocol(self):
        """Return the protocol of the first entry for this service in the
        services file, or None if that file or such an entry is missing."""
        if not os.path.exists(SERVICES_FILENAME):
            return None
        srv_name = self.get_name()
        with open(SERVICES_FILENAME) as services_fd:
            for line in services_fd:
                try:
                    fields = line.split()
                    if fields[0] == srv_name:
                        # the second field has the form port/protocol
                        return fields[1].split('/')[1]
                except IndexError:
                    # ignore malformed line
                    continue
        return None


class XFragmentParser(object):
    """Parse an extended internet services daemon configuration file, with one
       or more fragments, using a simple FSM. The FSM determines which *_state
       method should be invoked, at any point of the input data."""
    space_pat = re.compile(r'[ \t]+')

    def __init__(self, fragment_container, logger):
        self.logger = logger
        self._fragment = None
        self.container = fragment_container
        # name of the file being parsed; stays None when parse() is called
        # directly rather than via load_files()
        self.current_filename = None
        self.parse_func = self._initial_state

    def load_files(self, fragment_files):
        """Load a list of xfragment files into the container.

        Files that are missing, or that are not regular files, are skipped
        with a warning. An InvalidEntryException is logged as an error; any
        other exception raised during parsing is propagated."""
        for fname in [os.path.abspath(f) for f in fragment_files]:
            # is readable
            if not os.path.isfile(fname):
                self.logger.warning(('%s does not exist or is not a regular '
                    'file\n') % fname)
                continue
            prev_services_count = len(self.container.get_all_services())
            with open(fname, 'r') as fd:
                try:
                    self.current_filename = fname
                    self.parse(fd)
                except InvalidEntryException as exc:
                    self.logger.error('in file %s:\n%s\n' % (fname, str(exc)))
            new_services_count = len(self.container.get_all_services())
            if prev_services_count == new_services_count:
                self.logger.warning("%s does not contain a valid service fragment"
                        % fname)

    def _initial_state(self, tokens):
        """Expect 'service <name>'; any other line is ignored."""
        if tokens[0] == 'service' and len(tokens) == 2:
            service_name = tokens[1]
            self._fragment = XFragment(service_name, self.logger,
                                       self.current_filename)
            self.parse_func = self._open_brace_state

    def _open_brace_state(self, tokens):
        """Expect '{'; any other line is ignored."""
        if tokens[0] == '{' and len(tokens) == 1:
            self.parse_func = self._key_value_state

    def _key_value_state(self, tokens):
        """Expect either '}' or 'key = value'; lines with any other number
        of tokens are ignored.

        Raises UnsupportedOperatorException for operators other than '='."""
        if len(tokens) == 1 and tokens[0] == '}':
            self._close_brace_state()
        elif len(tokens) >= 3:
            key = tokens[0]
            operator = tokens[1]
            value = " ".join(tokens[2:])
            # TODO add support for += and -= operators
            if operator != '=':
                raise UnsupportedOperatorException(operator)
            self._fragment.add_attr(key, value)

    def _close_brace_state(self):
        """Validate the fragment that was just completed, and add it to the
        container.

        Raises MissingFieldException if the validation reports errors."""
        (errors, warnings) = self._fragment.validate()
        if errors == []:
            self.container.add_service(self._fragment)
        else:
            self.logger.error('errors for %s: %s;\n' % (self.current_filename,
                                                       "".join(errors)))
            raise MissingFieldException('%s;\n' % "\n".join(errors))
        if warnings:
            self.logger.warning('warnings for %s: %s;\n' % (self.current_filename,
                                                         "".join(warnings)))
        self.parse_func = self._initial_state


    def parse(self, input_fd):
        """Parse one or more configuration fragments from the supplied file
        descriptor, using *_state methods.

        Raises InvalidEntryException if the input ends in the middle of a
        fragment."""
        raw_fragment = input_fd.read()
        # normalise spacing
        lines = XFragmentParser.space_pat.sub(' ', raw_fragment)
        lines = [k.strip() for k in lines.split('\n')]

        self.parse_func = self._initial_state
        i = None
        for i, line in enumerate(lines):
            if line.startswith("#") or line.isspace() or line == '':
                continue
            tokens = line.split()
            self.parse_func(tokens)
        if self.parse_func != self._initial_state:
            # build the suffix separately: '%' binds tighter than a
            # conditional expression, which used to drop the message prefix
            suffix = '' if i is None else (' in line %d\n' % i)
            msg = 'fragment does not end with a "}"%s' % suffix
            raise InvalidEntryException(msg)


class XFragmentContainer(BaseServiceContainer):
    """Holds the XFragment objects that are loaded from xfragment files."""

    def __init__(self, logger, fragment_files=None):
        """Optionally load the given list of fragment files right away."""
        super(XFragmentContainer, self).__init__(logger)
        if not fragment_files:
            return
        self.load_files(fragment_files)

    def load_files(self, fragment_files):
        """Parse the given fragment files into this container."""
        parser = XFragmentParser(self, self.logger)
        parser.load_files(fragment_files)

    def get_valid_services(self):
        """Return the list of fragments whose server program exists."""
        usable = []
        for fragment in self.all_services.values():
            if fragment.has_existing_server_path():
                usable.append(fragment)
        return usable

class CmdLineArguments(object):
    """Class to parse, check and encapsulate command-line argument values.

    Attributes set by the constructor:
      verbose               True if --verbose was given
      update_mode           False if --sanity-check was given, True otherwise
      fragments_to_check    files to check (empty list in update mode)
      inetd_conf_fname      INETD_CONF_FILENAME or /etc/inetd.conf
      fragments_dir         RECONF_INETD_FRAGMENTS_DIR or the default
      shadow_fragments_dir  SHADOW_FRAGMENTS_DIR or the default
    """

    def __init__(self):
        """Parse sys.argv; exits on --version, --help and usage errors."""
        from optparse import OptionParser

        usage = \
"""
        %prog [--verbose]
        %prog --sanity-check fragment [... fragment]"""
        parser = OptionParser(usage)
        #parser.add_option('-d', '--dry-run', dest='dry_run', default=False,
        #                  help="""do not actually update inetd.conf, just show
        #                  what would have happened""")
        parser.add_option('-c', '--sanity-check', dest='fragments_to_check',
                          default='', help="""test the validity of the
                          xinetd.conf-like configuration fragments, as
                          specified by a space-separated list of files""")
        parser.add_option('-v', '--verbose', action='store_true',
                          dest='verbose', default=False,
                          help='explain what happens')
        parser.add_option('-V', '--version', dest='show_version',
                          action="store_true", default=False,
                          help='show version and exit')

        (options, args) = parser.parse_args()

        if options.show_version:
            self.display_version_and_exit()

        #self.dry_run = options.dry_run
        self.verbose = options.verbose
        if options.fragments_to_check:
            self.update_mode = False
            # the option takes the first file; the rest are positional
            self.fragments_to_check = [options.fragments_to_check] + args
        else:
            self.update_mode = True
            self.fragments_to_check = []
            if args:
                # parser.error() prints the message and exits
                parser.error("Unknown argument %s" % " ".join(args))

        self.inetd_conf_fname = os.environ.get('INETD_CONF_FILENAME') \
                                or '/etc/inetd.conf'
        self.fragments_dir = os.environ.get('RECONF_INETD_FRAGMENTS_DIR') \
                             or RECONF_INETD_FRAGMENTS_DIR
        self.shadow_fragments_dir = os.environ.get('SHADOW_FRAGMENTS_DIR') \
                             or SHADOW_FRAGMENTS_DIR

    def display_version_and_exit(self):
        """Print the program version and exit successfully."""
        sys.stdout.write('reconf-inetd %s\n' % __version__)
        # sys.exit(): the exit() builtin is provided by the site module for
        # interactive use, and is not guaranteed to exist
        sys.exit(0)

class InetdRestarter(object):
    """Restart (or reload) inetd, via invoke-rc.d, after inetd.conf has been
    changed."""

    # a non-empty UPDATE_INETD_FAKE_IT environment variable (read when the
    # class is defined) prevents restart() from invoking anything
    FAKE_INVOCATION = os.environ.get("UPDATE_INETD_FAKE_IT")

    def __init__(self, change_set, logger):
        """Choose the init script action: 'force-reload' when the change set
        only removes services, 'restart' otherwise."""
        self.logger = logger
        if any(change_set.services_to_remove()) and \
           not (any(change_set.services_to_add()) or \
                any(change_set.services_to_enable())):
            self.sysv_action = 'force-reload'
        else:
            self.sysv_action = 'restart'

    def restart(self):
        """Run the chosen action of the first /etc/init.d/*inetd script, if
        any, and log an error if that fails."""
        import subprocess

        sysv_script = glob('/etc/init.d/*inetd')
        self.logger.debug("About to %s inetd via invoke-rc.d\n" %
                self.sysv_action)
        if InetdRestarter.FAKE_INVOCATION or not sysv_script:
            return

        service = os.path.basename(sysv_script[0])
        # an argument list, so that no shell gets to interpret the script name
        argv = ['/usr/sbin/invoke-rc.d', service, self.sysv_action]
        self.logger.info("about to run: %s" % ' '.join(argv))
        try:
            proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT)
            output = proc.communicate()[0]
            status = proc.returncode
        except OSError as exc:
            # e.g. invoke-rc.d is not installed
            status, output = 1, str(exc)
        if status != 0:
            if not isinstance(output, str):
                output = output.decode('utf-8', 'replace')
            if output.endswith('\n'):
                output = output[:-1]
            self.logger.error('failed to restart %s via invoke-rc.d: %s' %
                    (service, output))


def update_inetd_conf(logger, inetd_conf_fname, fragment_files, shadow_files,
                      shadow_fragments_dir):
    """Bring inetd.conf in line with the given fragment and shadow files.

    When there is anything to change: rewrite inetd.conf, create/remove the
    shadow fragment files in shadow_fragments_dir, and restart inetd.
    Otherwise just log that there is nothing to do."""
    inetd_entries = InetdServiceContainer(logger)
    with open(inetd_conf_fname) as conf_fd:
        inetd_entries.load_service_entries(conf_fd)

    fragments = XFragmentContainer(logger, fragment_files)
    shadows = XFragmentContainer(logger, shadow_files)
    changes = InetdConfChangeSet(inetd_entries, fragments, shadows, logger)

    if not changes.is_not_empty():
        logger.info('no changes to be made')
        return

    inetd_entries.prepare_changes(changes)
    inetd_entries.persist(inetd_conf_fname)
    changes.update_shadow_fragments(shadow_fragments_dir)
    restarter = InetdRestarter(changes, logger)
    restarter.restart()

def die_unless_root_or_testing():
    """Exit (successfully) unless running as root or in a test setup.

    A test setup is assumed when either of the RECONF_INETD_LOG and
    INETD_CONF_FILENAME environment variables is set."""
    # no point in running for real unless we have root privileges
    testing = os.environ.get('RECONF_INETD_LOG') is not None \
            or os.environ.get('INETD_CONF_FILENAME') is not None
    if not testing and os.getuid() != 0:
        sys.stderr.write("reconf-inetd in default mode requires root privileges\n")
        # sys.exit(): the exit() builtin is only meant for interactive use
        sys.exit(0)

def main():
    """Program entry point.

    In update mode (the default): update inetd.conf according to the
    fragment and shadow fragment files, exiting early (and successfully)
    when the required files are missing. In sanity-check mode: parse the
    given fragment files, so that their problems get logged.

    Logging goes to stdout (at the RECONF_INETD_LOGLEVEL level if set, else
    INFO with --verbose, else ERROR) and, in update mode, also to a log file
    (RECONF_INETD_LOG or /var/log/reconf-inetd.log)."""
    args = CmdLineArguments()

    logger = logging.getLogger('InetdConfLogger')
    logger.setLevel(logging.DEBUG) # to be overriden by specific handlers

    stdout_loglevel = os.environ.get('RECONF_INETD_LOGLEVEL')
    if stdout_loglevel is not None:
        stdout_loglevel = int(stdout_loglevel)
    elif args.verbose:
        stdout_loglevel = logging.INFO
    else:
        stdout_loglevel = logging.ERROR
    stdout_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)
    stdout_handler.setLevel(stdout_loglevel)

    if args.update_mode:
        die_unless_root_or_testing()
        log_filename = os.environ.get('RECONF_INETD_LOG') \
                or '/var/log/reconf-inetd.log'
        file_handler = logging.FileHandler(log_filename)
        file_handler.setLevel(logging.INFO)
        log_formatter = logging.Formatter("%(asctime)s - %(message)s")
        file_handler.setFormatter(log_formatter)
        logger.addHandler(file_handler)

        # NB sys.exit() rather than the exit() builtin, which is provided by
        # the site module for interactive use only
        if not os.path.exists(args.inetd_conf_fname) or \
           not os.path.exists(args.fragments_dir):
            logger.debug('either or both of %s, %s are missing; nothing to do'
                    % (args.inetd_conf_fname, args.fragments_dir))
            sys.exit(0)
        if not os.path.isfile(args.inetd_conf_fname):
            logger.error('%s is not a regular file\n' %
                    args.inetd_conf_fname)
            sys.exit(0)
        if not os.path.isdir(args.fragments_dir):
            logger.error('%s is not a directory\n' % args.fragments_dir)
            sys.exit(0)
        fragment_files = glob('%s/*' % args.fragments_dir)
        shadow_files = glob('%s/*' % args.shadow_fragments_dir)
        if not fragment_files and not shadow_files:
            logger.debug('no reconf/shadow fragments found; nothing to do')
            sys.exit(0)
        update_inetd_conf(logger, args.inetd_conf_fname, fragment_files,
                          shadow_files, args.shadow_fragments_dir)
    else:
        try:
            XFragmentContainer(logger, args.fragments_to_check)
        except MissingFieldException:
            # just catch it; errors are logged where the exception is thrown
            pass
        except UnsupportedOperatorException as exc:
            # a defect of the checked fragment, not of this program; it is
            # not logged where it is thrown, so report it here
            logger.error('unsupported operator "%s" in fragment\n' % exc)

# Script entry point. Unexpected errors are reported on stderr together with
# a traceback, but never turned into a failure exit status.
if __name__ == "__main__":
    try:
        main()
    except Exception:
        import traceback
        sys.stderr.write(('reconf-inetd terminated unexpectedly! Please file '
                          'a bug report against\nthe reconf-inetd package '
                          'with the following info:\n'))
        traceback.print_exc()
    # sys.exit() rather than the site-provided exit() builtin
    sys.exit(0) # we must always exit successfully
