#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Nagios compatible check for XMPP servers.
# Copyright (C) 2015-2021  Jan Dittberner
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
from __future__ import absolute_import

from datetime import datetime
from select import select
from xml.sax.handler import ContentHandler, feature_namespaces
import logging
import os.path
import socket
import ssl

from defusedxml.sax import make_parser
import nagiosplugin

__author__ = "Jan Dittberner"
__version__ = "0.3.2"


# XML namespace URIs. With namespace processing enabled the SAX handler in
# this file receives element names as (namespace URI, local name) tuples and
# compares them against tuples built from these constants.
NS_IETF_XMPP_SASL = 'urn:ietf:params:xml:ns:xmpp-sasl'
NS_IETF_XMPP_TLS = 'urn:ietf:params:xml:ns:xmpp-tls'
NS_IETF_XMPP_STREAMS = 'urn:ietf:params:xml:ns:xmpp-streams'
NS_JABBER_CAPS = 'http://jabber.org/protocol/caps'
NS_ETHERX_STREAMS = 'http://etherx.jabber.org/streams'

# Values stored in ``XmppResponseHandler.state``. They are also the values
# accepted by the ``expected_state`` argument of ``Xmpp.handle_xmpp_stanza``.
XMPP_STATE_NEW = 'new'
XMPP_STATE_STREAM_START = 'started'
XMPP_STATE_RECEIVED_FEATURES = 'received features'
XMPP_STATE_FINISHED = 'finished'
XMPP_STATE_ERROR = 'error'
XMPP_STATE_PROCEED_STARTTLS = 'proceed with starttls'

# Module wide logger, registered under the 'nagiosplugin' logger name.
_LOG = logging.getLogger('nagiosplugin')


class XmppException(Exception):
    """
    Error raised for any failure detected during the XMPP check.

    The human readable description is stored in the ``message`` attribute.
    It is also handed to :py:class:`Exception` so that ``str(exc)``,
    ``exc.args`` and ``%s`` log formatting show the description instead of
    an empty string.

    :param message: description of the failure

    """
    def __init__(self, message):
        self.message = message
        super(XmppException, self).__init__(message)


class XmppStreamError(object):
    """
    Container for the details of an XMPP stream error.

    Attributes:

    ``condition``
        local name of the error condition element, ``None`` until one is set
    ``text``
        optional descriptive text of the error, ``None`` until one is set
    ``other_elements``
        dictionary with information about any further child elements of the
        error; every instance owns its own dictionary

    """
    condition = None
    text = None

    def __init__(self):
        # A dictionary defined on the class would be shared by all instances,
        # so details of one error would leak into the next one.
        self.other_elements = {}

    def str(self):
        """
        Return the error description.

        The result is ``"<condition>: <text>"`` if a text is available and
        the bare condition otherwise. The condition may be ``None`` when no
        condition has been recorded.

        """
        if self.text:
            return "{condition}: {text}".format(
                condition=self.condition, text=self.text)
        return self.condition

    def __str__(self):
        """
        Return the error description as a string.

        This makes ``str(error)`` and ``"{}".format(error)`` produce the
        description instead of the default object representation.

        """
        # str() may return None, __str__ has to return a string
        return "{0}".format(self.str())


class XmppResponseHandler(ContentHandler):
    """
    SAX content handler for XMPP stream content.

    The handler tracks the stack of currently open elements and derives the
    conversation state (one of the ``XMPP_STATE_*`` constants) from the
    elements it sees. It records stream attributes, STARTTLS support,
    entity capabilities, SASL mechanisms and stream error details.

    All mutable bookkeeping containers are created per instance. A handler
    that is created for the restarted stream after STARTTLS therefore does
    not inherit information gathered on the unencrypted stream.

    :param expect_starttls: if ``True`` :py:meth:`is_valid_start` fails
        unless the server offered STARTTLS

    """
    # Immutable defaults, replaced on the instance while parsing.
    starttls = False
    tlsrequired = False
    state = XMPP_STATE_NEW
    streaminfo = None
    errorinstance = None
    level = 0

    def __init__(self, expect_starttls):
        self.expect_starttls = expect_starttls
        # names of all elements seen so far
        self.seen_elements = set()
        # character data found in SASL mechanism elements
        self.mechanisms = []
        # attributes of the entity capabilities element
        self.capabilities = {}
        # stack of currently open elements, innermost element last
        self.inelem = []
        super(XmppResponseHandler, self).__init__()

    def startElementNS(self, name, qname, attrs):
        """
        Handle the start of an element.

        :param name: ``(namespace URI, local name)`` tuple of the element
        :param qname: qualified name of the element (not used)
        :param attrs: attributes of the element

        """
        # The parent is looked up before pushing the new element, the stack
        # is empty for the root element.
        parent = self.inelem[-1] if self.inelem else None
        self.inelem.append(name)
        self.seen_elements.add(name)
        if name == (NS_ETHERX_STREAMS, 'stream'):
            self.state = XMPP_STATE_STREAM_START
            self.streaminfo = {
                attrname: attrs.getValueByQName(attrname)
                for attrname in attrs.getQNames()}
        elif name == (NS_IETF_XMPP_TLS, 'starttls'):
            self.starttls = True
        elif (
            parent == (NS_IETF_XMPP_TLS, 'starttls') and
            name == (NS_IETF_XMPP_TLS, 'required')
        ):
            self.tlsrequired = True
            _LOG.info("info other side requires TLS")
        elif name == (NS_JABBER_CAPS, 'c'):
            for attrname in attrs.getQNames():
                self.capabilities[attrname] = attrs.getValueByQName(attrname)
        elif name == (NS_ETHERX_STREAMS, 'error'):
            self.state = XMPP_STATE_ERROR
            self.errorinstance = XmppStreamError()
            # make sure the details are never stored in a shared dictionary
            self.errorinstance.other_elements = {}
        elif (
            self.state == XMPP_STATE_ERROR and
            name != (NS_IETF_XMPP_STREAMS, 'text')
        ):
            if name[0] == NS_IETF_XMPP_STREAMS:
                self.errorinstance.condition = name[1]
            else:
                self.errorinstance.other_elements[name] = {'attrs': {
                    attrname: attrs.getValueByQName(attrname)
                    for attrname in attrs.getQNames()
                }}
        _LOG.debug('start %s (%s)', name, attrs.getQNames())

    def endElementNS(self, name, qname):
        """
        Handle the end of an element and update the conversation state.

        :param name: ``(namespace URI, local name)`` tuple of the element
        :param qname: qualified name of the element (not used)
        :raises XmppException: when a stream error element or a STARTTLS
            failure element is closed

        """
        if name == (NS_ETHERX_STREAMS, 'features'):
            self.state = XMPP_STATE_RECEIVED_FEATURES
        elif name == (NS_ETHERX_STREAMS, 'stream'):
            self.state = XMPP_STATE_FINISHED
        elif name == (NS_ETHERX_STREAMS, 'error'):
            # use the description of the error, formatting the object itself
            # depends on it providing a string conversion
            raise XmppException("XMPP stream error: {error}".format(
                error=self.errorinstance.str()))
        elif name == (NS_IETF_XMPP_TLS, 'proceed'):
            self.state = XMPP_STATE_PROCEED_STARTTLS
        elif name == (NS_IETF_XMPP_TLS, 'failure'):
            raise XmppException("starttls initiation failed")
        _LOG.debug('end %s', name)
        del self.inelem[-1]

    def characters(self, content):
        """
        Handle character data of the innermost open element.

        :param content: the character data

        """
        elem = self.inelem[-1] if self.inelem else None
        if elem == (NS_IETF_XMPP_SASL, 'mechanism'):
            self.mechanisms.append(content)
        elif self.state == XMPP_STATE_ERROR:
            if elem == (NS_IETF_XMPP_STREAMS, 'text'):
                # the parser may deliver the text in several pieces
                self.errorinstance.text = (
                    (self.errorinstance.text or '') + content)
            else:
                # There is no entry for the error element itself and for
                # condition elements, create one on demand.
                self.errorinstance.other_elements.setdefault(
                    elem, {})['text'] = content
        else:
            _LOG.warning('ignored content in %s: %s', self.inelem, content)

    def is_valid_start(self):
        """
        Check that the server answered the stream start properly.

        A stream version other than ``1.0`` (or a missing version) is only
        logged as a warning.

        :returns: ``True`` if the start is valid
        :raises XmppException: if the feature list has not been received
            completely or if STARTTLS was expected but not offered

        """
        if not self.state == XMPP_STATE_RECEIVED_FEATURES:
            raise XmppException('XMPP feature list is not finished')
        if self.expect_starttls is True and self.starttls is False:
            raise XmppException('expected STARTTLS capable service')
        # .get() avoids a KeyError when the version attribute is missing
        version = (self.streaminfo or {}).get('version')
        if version != '1.0':
            _LOG.warning('unknown stream version %s', version)
        return True


class Xmpp(nagiosplugin.Resource):
    """
    Xmpp resource.

    Opens a TCP connection to an XMPP server, starts an XMPP stream,
    optionally upgrades the connection with STARTTLS and reports the time
    needed as well as the remaining validity of the server certificate.

    :param host_address: host name or address to connect to
    :param port: TCP port to connect to
    :param ipv6: ``True`` to enforce IPv6, ``False`` to enforce IPv4,
        ``None`` to accept any address family
    :param is_server: ``True`` to open a server to server stream, otherwise
        a client to server stream is opened
    :param starttls: ``True`` if the server is expected to offer STARTTLS
    :param servername: name used in the ``to`` attribute of the stream and
        as server host name for TLS
    :param checkcerts: ``True`` to verify the server certificate
    :param caroots: optional path to a file or directory with CA
        certificates

    """
    # state and cause are evaluated by XmppContext if the check failed
    state = nagiosplugin.Unknown
    cause = None
    # the connection to the server, set by probe()
    socket = None
    # remaining validity of the server certificate in days, set by
    # initiate_tls() if certificates are checked
    daysleft = None

    def __init__(
        self, host_address, port, ipv6, is_server, starttls,
        servername, checkcerts, caroots
    ):
        self.address = host_address
        self.port = port
        self.ipv6 = ipv6
        self.is_server = is_server
        self.starttls = starttls
        self.servername = servername
        self.checkcerts = checkcerts
        self.caroots = caroots
        self.make_parser()
        self.set_content_handler()

    def make_parser(self):
        """
        Create and configure a SAX parser instance.

        The parser is stored as ``self.parser`` and has namespace processing
        enabled.

        """
        self.parser = make_parser()
        self.parser.setFeature(feature_namespaces, True)

    def set_content_handler(self):
        """
        Set the XMPP SAX content handler.

        A new :py:class:`XmppResponseHandler` is created, stored as
        ``self.contenthandler`` and attached to the current parser.

        """
        self.contenthandler = XmppResponseHandler(
            expect_starttls=self.starttls)
        self.parser.setContentHandler(self.contenthandler)

    def get_addrinfo(self):
        """
        Perform the DNS lookup and return a list of potential socket address
        tuples as returned by :py:func:`socket.getaddrinfo`.

        The address family is restricted according to ``self.ipv6``.

        """
        if self.ipv6 is None:
            addrfamily = 0
        elif self.ipv6 is True:
            addrfamily = socket.AF_INET6
        else:
            addrfamily = socket.AF_INET
        return socket.getaddrinfo(
            self.address, self.port, addrfamily, socket.SOCK_STREAM,
            socket.IPPROTO_TCP)

    def open_socket(self, addrinfo):
        """
        Open a client socket based on information in the addrinfo list of
        tuples.

        The entries are tried in order, the first socket that can be
        connected is returned.

        :param addrinfo: list of tuples as returned by
            :py:meth:`get_addrinfo`
        :raises XmppException: if no connection could be established, this
            includes an empty ``addrinfo`` list

        """
        # initialised up front so that an empty addrinfo list is reported
        # as a failed connection too
        sock = None
        for af, socktype, proto, _canonname, sa in addrinfo:
            try:
                sock = socket.socket(af, socktype, proto)
            except socket.error:
                sock = None
                continue
            try:
                sock.connect(sa)
            except socket.error:
                sock.close()
                sock = None
                continue
            break
        if sock is None:
            raise XmppException("could not open socket")
        return sock

    def handle_xmpp_stanza(
        self, message_str, timeout=0.1, expected_state=None
    ):
        """
        Handle a single XMPP message.

        This method sends the message specified as ``message_str`` and receives
        input from the remote XMPP server. The response message is passed to
        the SAX parser.

        Reading stops as soon as the socket has not become readable within
        ``timeout`` seconds or the remote side has closed the connection.

        The method can optionally check for a specific state of the SAX content
        handler that can be specified as ``expected_state`` which needs to be
        one of the ``XMPP_STATE_*`` constants specified in the head of this
        file.

        :raises XmppException: if ``expected_state`` is given and the content
            handler is in a different state after parsing the response

        """
        self.socket.sendall(message_str.encode('utf-8'))
        _LOG.debug("wrote %s", message_str)
        chunks = []
        while True:
            rready, wready, xready = select([self.socket], [], [], timeout)
            if self.socket in rready:
                data = self.socket.recv(4096)
                if not data:
                    break
                chunks.append(data)
            else:
                break
        xmltext = b''.join(chunks).decode('utf-8')
        _LOG.debug("read %s", xmltext)
        self.parser.feed(xmltext)
        if (
            expected_state is not None and
            self.contenthandler.state != expected_state
        ):
            raise XmppException(
                "unexpected state %s" % self.contenthandler.state)

    def start_stream(self):
        """
        Start a XMPP conversation with the server.

        Sends the stream header and expects the server to answer with its
        complete feature list.

        """
        # the stream headers for server to server and client to server
        # connections differ in the default namespace only
        namespace = 'jabber:server' if self.is_server else 'jabber:client'
        self.handle_xmpp_stanza((
            "<?xml version='1.0' ?><stream:stream to='{servername}' "
            "xmlns='{namespace}' "
            "xmlns:stream='http://etherx.jabber.org/streams' "
            "version='1.0'>"
        ).format(servername=self.servername, namespace=namespace),
            expected_state=XMPP_STATE_RECEIVED_FEATURES)

    def setup_ssl_context(self):
        """
        Setup the SSL/TLS context.

        Certificate and host name verification are switched off if
        certificate checks are disabled. Otherwise verification is required
        and the CA certificates from ``self.caroots`` (a file or a
        directory) or the default certificates are loaded.

        :returns: the configured :py:class:`ssl.SSLContext`

        """
        context = ssl.create_default_context()
        # add to the options instead of replacing them, an assignment would
        # drop the options preset by create_default_context()
        context.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
        if not self.checkcerts:
            # check_hostname has to be disabled before verify_mode can be
            # set to CERT_NONE
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        else:
            context.verify_mode = ssl.CERT_REQUIRED
            if self.caroots:
                if os.path.isfile(self.caroots):
                    kwargs = {'cafile': self.caroots}
                else:
                    kwargs = {'capath': self.caroots}
                context.load_verify_locations(**kwargs)
            else:
                context.load_default_certs()
        return context

    def initiate_tls(self):
        """
        Switch the client socket to TLS by performing a TLS handshake. If
        certificate checks are requested the method checks the certificate
        validity.

        After the handshake a new parser and content handler are created and
        the stream is started again on the encrypted connection.

        :raises XmppException: if the server does not proceed with STARTTLS,
            if the TLS handshake or the certificate verification fails or if
            the restarted stream is not valid

        """
        _LOG.debug("start initiate_tls()")
        self.handle_xmpp_stanza(
            "<starttls xmlns='{xmlns}'/>".format(xmlns=NS_IETF_XMPP_TLS),
            expected_state=XMPP_STATE_PROCEED_STARTTLS)
        sslcontext = self.setup_ssl_context()
        try:
            self.socket = sslcontext.wrap_socket(
                self.socket, server_hostname=self.servername)
            _LOG.info("TLS socket setup successful")
        # CertificateError is a subclass of SSLError in recent Python
        # versions and has to be handled first to be reachable at all
        except ssl.CertificateError as certerr:
            raise XmppException("Certificate error %s" % (certerr,))
        except ssl.SSLError as ssle:
            # strerror may be empty, fall back to the exception itself
            raise XmppException(
                "SSL error %s" % (ssle.strerror or ssle,))
        # the restarted stream must not expect another STARTTLS offer
        self.starttls = False
        # reset infos retrieved previously as written in RFC 3920 sec. 5.
        self.parser.reset()
        if self.checkcerts:
            certinfo = self.socket.getpeercert()
            _LOG.debug("got the following certificate info: %s", certinfo)
            _LOG.info(
                "certificate is valid from %s until %s",
                certinfo['notBefore'], certinfo['notAfter'])
            enddate = ssl.cert_time_to_seconds(certinfo['notAfter'])
            remaining = datetime.fromtimestamp(enddate) - datetime.now()
            self.daysleft = remaining.days
        # start new parsing
        self.make_parser()
        self.set_content_handler()
        self.start_stream()
        if not self.contenthandler.is_valid_start():
            raise XmppException("no valid response to XMPP client request")
        _LOG.debug("end initiate_tls()")

    def handle_xmpp(self):
        """
        Perform the XMPP conversation.

        The stream is started, upgraded to TLS if STARTTLS is expected or
        required by the server and closed again.

        :raises XmppException: if any step of the conversation fails

        """
        _LOG.debug("start handle_xmpp()")
        self.start_stream()
        if not self.contenthandler.is_valid_start():
            raise XmppException("no valid response to XMPP client request")
        if self.starttls is True or self.contenthandler.tlsrequired:
            self.initiate_tls()
        self.handle_xmpp_stanza("</stream:stream>")
        _LOG.debug("end handle_xmpp()")

    def probe(self):
        """
        Entry point for the nagiosplugin.Resource APIs. Wraps the call to
        :py:meth:`Xmpp.handle_xmpp` and handles errors appropriately.

        This method is a generator. On success it yields the ``time`` metric
        with the duration of the check in seconds and the ``daysleft``
        metric. If the DNS lookup or the XMPP conversation fails, ``state``
        and ``cause`` are set and a single ``time`` metric with the value
        ``unknown`` is yielded so that the failure gets evaluated.

        """
        start = datetime.now()
        _LOG.debug("start probe() at %s", start)
        try:
            addrinfo = self.get_addrinfo()
            self.socket = self.open_socket(addrinfo)
            try:
                self.handle_xmpp()
            finally:
                self.socket.close()
            self.parser.close()
        except socket.gaierror as e:
            self.state = nagiosplugin.Critical
            self.cause = str(e)
            _LOG.debug("got an gaierror %s", e)
            # The metric has to be yielded, the value of a return statement
            # in a generator is not seen by the consumer of the generator.
            yield nagiosplugin.Metric("time", "unknown")
            return
        except XmppException as e:
            self.state = nagiosplugin.Critical
            self.cause = e.message
            _LOG.debug("got an XmppException %s", e)
            # No 'daysleft' metric is emitted here, a non numeric value
            # cannot be compared with the thresholds for the days.
            yield nagiosplugin.Metric('time', 'unknown')
            return
        end = datetime.now()
        _LOG.debug("end probe() at %s", end)
        yield nagiosplugin.Metric(
            'time', (end - start).total_seconds(), 's', min=0)
        yield nagiosplugin.Metric('daysleft', self.daysleft, 'd')


class XmppContext(nagiosplugin.ScalarContext):
    """
    Scalar context for the connection time that also reports failures
    recorded by the resource during the XMPP check.

    """

    def evaluate(self, metric, resource):
        """
        Return the result for ``metric``.

        If the resource has recorded a ``cause`` the result carries the
        state and the cause of the resource. Otherwise the metric is
        evaluated by the regular scalar context logic.

        """
        if not resource.cause:
            return super(XmppContext, self).evaluate(metric, resource)
        failure = nagiosplugin.Result(resource.state, resource.cause, metric)
        return failure


class DaysValidContext(nagiosplugin.Context):
    """
    Context that evaluates the remaining validity of the certificate.

    """
    fmt_hint = "less than {value} days"

    def __init__(
        self, name, warndays=0, critdays=0,
        fmt_metric='certificate valid for {value} days'
    ):
        super(DaysValidContext, self).__init__(name, fmt_metric=fmt_metric)
        self.warndays = warndays
        self.critdays = critdays
        self.warning = nagiosplugin.Range('@%d:' % warndays)
        self.critical = nagiosplugin.Range('@%d:' % critdays)

    def evaluate(self, metric, resource):
        """
        Compare the metric value with the critical and warning ranges.

        The critical range is tested before the warning range. An OK result
        without metric is returned if certificates are not checked or if the
        metric has no value.

        """
        if not resource.checkcerts or metric.value is None:
            return nagiosplugin.Result(nagiosplugin.Ok)
        thresholds = (
            (self.critical, nagiosplugin.Critical, self.critdays),
            (self.warning, nagiosplugin.Warn, self.warndays),
        )
        for limit_range, state, days in thresholds:
            if limit_range.match(metric.value):
                hint = self.fmt_hint.format(value=days)
                return nagiosplugin.Result(state, hint=hint, metric=metric)
        return nagiosplugin.Result(nagiosplugin.Ok, "", metric)

    def performance(self, metric, resource):
        """
        Return the performance data for the metric or ``None`` if
        certificates are not checked or the metric has no value.

        """
        if not resource.checkcerts or metric.value is None:
            return None
        return nagiosplugin.Performance('daysvalid', metric.value, '')


class XmppSummary(nagiosplugin.Summary):
    """
    Summary that lists every result when the check is ok.

    """

    def ok(self, results):
        """
        Return the string forms of all ``results`` joined by ``", "``.

        """
        return ", ".join(map(str, results))


@nagiosplugin.guarded
def main():
    """
    Handle command line parameters and initiate the check.

    The port defaults to 5269 for server to server checks and to 5222 for
    client to server checks. The server name defaults to the host address.
    The threshold and verbosity options are handed to the contexts and to
    the check, all remaining options are passed to :py:class:`Xmpp`.

    """
    import argparse
    parser = argparse.ArgumentParser(description="Check XMPP services")
    host_address = parser.add_mutually_exclusive_group(required=True)
    host_address.add_argument("-H", "--host-address", help="host address")
    host_address.add_argument(
        "--hostname", help="host name, alternative for host-address",
        dest="host_address")
    parser.add_argument(
        "-p", "--port", help="port", type=int)
    is_server = parser.add_mutually_exclusive_group()
    is_server.add_argument(
        "--s2s", dest="is_server", action='store_true',
        help="server to server (s2s)")
    is_server.add_argument(
        "--c2s", dest="is_server", action='store_false',
        help="client to server (c2s)")
    ipv6 = parser.add_mutually_exclusive_group()
    ipv6.add_argument(
        "-4", "--ipv4", dest="ipv6", action='store_false',
        help="enforce IPv4")
    ipv6.add_argument(
        "-6", "--ipv6", dest="ipv6", action='store_true',
        help="enforce IPv6")
    parser.add_argument(
        "--servername", help="server name to be used")
    parser.add_argument(
        "--starttls",
        action='store_true', help="check whether the service allows starttls")
    parser.add_argument(
        "-w", "--warning", metavar="SECONDS", default='',
        help="return warning if connection setup takes longer than SECONDS")
    parser.add_argument(
        "-c", "--critical", metavar="SECONDS", default='',
        help="return critical if connection setup takes longer than SECONDS")
    parser.add_argument(
        "--no-check-certificates", dest='checkcerts', action='store_false',
        help="do not check whether the XMPP server's certificate is valid")
    parser.add_argument(
        "-r", "--ca-roots", dest="caroots",
        help="path to a file or directory where CA certificates can be found")
    parser.add_argument(
        "--warn-days", dest="warndays", type=int, default=0,
        help=(
            "set state to WARN if the certificate is valid for not more than "
            "the given number of days"))
    parser.add_argument(
        "--crit-days", dest="critdays", type=int, default=0,
        help=(
            "set state to CRITICAL if the certificate is valid for not more "
            "than the given number of days"))
    parser.add_argument(
        '-v', '--verbose', action='count', default=0,
        help="increase output verbosity (may be given multiple times)")
    parser.set_defaults(is_server=False, ipv6=None, checkcerts=True)
    args = parser.parse_args()
    if args.port is None:
        # default ports for server to server and client to server streams
        args.port = 5269 if args.is_server else 5222
    if args.servername is None:
        args.servername = args.host_address
    kwargs = vars(args)
    # Remove the options that are not understood by the Xmpp resource, the
    # remaining entries match the parameters of Xmpp.__init__.
    warning, critical, warndays, critdays, verbose = [
        kwargs.pop(arg) for arg in
        ('warning', 'critical', 'warndays', 'critdays', 'verbose')
    ]
    check = nagiosplugin.Check(
        Xmpp(**kwargs),
        XmppContext(
            'time', warning, critical, fmt_metric="request took {value}{uom}"),
        DaysValidContext('daysleft', warndays, critdays),
        XmppSummary(),
    )
    check.main(verbose=verbose, timeout=0)


# Run the check only if the file is executed as a script, importing the
# module has no side effects.
if __name__ == "__main__":
    main()
