from base64 import b64decode, b64encode
from flask import Response
from flask import _request_ctx_stack as stack
from flask import make_response
from flask import request
from functools import wraps
from socket import gethostname
from os import environ
import gssapi

_srvName = None
_srvCred = None
_log = None


def init_kerberos(logger, service='HTTP', hostname=gethostname()):
    '''
    Configure the GSSAPI service name, and validate the presence of the
    appropriate principal in the kerberos keytab.

    :param app: a flask application
    :type app: flask.Flask
    :param service: GSSAPI service name
    :type service: str
    :param hostname: hostname the service runs under
    :type hostname: str
    '''
    global _srvName, _log
    _log = logger
    srvName = gssapi.Name("{}@{}".format(service, hostname), name_type=gssapi.NameType.hostbased_service)
    _srvName = srvName.canonicalize(gssapi.MechType.kerberos)

    if 'KRB5_KTNAME' not in environ:
        _log.warn("krb5auth: set KRB5_KTNAME to your keytab file")
    else:
        try:
            global _srvCred
            _srvCred = gssapi.Credentials(name = _srvName, usage = 'accept')
        except Exception as ex:
            _log.warn("krb5auth: {}.".format(ex))
        else:
            _log.info("krb5auth: server credentials: {}.".format(_srvCred.name))


def _unauthorized():
    '''
    Indicate that authentication is required
    '''
    return Response('Unauthorized', 401, {'WWW-Authenticate': 'Negotiate'})


def _forbidden():
    '''
    Indicate a complete authentication failure
    '''
    return Response('Forbidden', 403)


def requires_authentication(function):
    '''
    Require that the wrapped view function only be called by users
    authenticated with Kerberos. The view function will have the authenticated
    users principal passed to it as its first argument.

    :param function: flask view function
    :type function: function
    :returns: decorated function
    :rtype: function
    '''
    @wraps(function)
    def decorated(*args, **kwargs):
        header = request.headers.get("Authorization", "")
        if header.startswith('Negotiate '):
            global _srvCred, _log
            _log.debug("krb5auth: got negotiation.")
            token = b64decode(header[10:])

            # Init server.
            srvCtx  = gssapi.SecurityContext(creds = _srvCred, usage = 'accept')
            try:
                srvToken = srvCtx.step(token)
            except Exception as ex:
                _log.info(ex)
                return _forbidden()

            if srvCtx.complete:
                _log.debug("krb5auth: negotiation complete.")
                clientCredsToken = None
                if srvCtx.delegated_creds:
                    clientCredsToken = srvCtx.delegated_creds.export()
                response = function(srvCtx.initiator_name, clientCredsToken, *args, **kwargs)
                response = make_response(response)
            else:
                _log.debug("krb5auth: negotiation incomplete.")
                response = _unauthorised()
            if srvToken:
                response.headers['WWW-Authenticate'] = 'Negotiate ' + b64encode(srvToken).decode("utf-8")
            return response

        _log.debug("krb5auth: No negotiation found.")
        return _unauthorized()
    return decorated