/*
 * Copyright (C) 2009 Chase Douglas
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * In addition, as a special exception, the copyright holders give
 * permission to link the code of portions of this program with the
 * OpenSSL library under certain conditions as described in each
 * individual source file, and distribute linked combinations
 * including the two.
 * You must obey the GNU General Public License in all respects
 * for all of the code used other than OpenSSL.  If you modify
 * file(s) with this exception, you may extend this exception to your
 * version of the file(s), but you are not obligated to do so.  If you
 * do not wish to do so, delete this exception statement from your
 * version.
 */

#include <errno.h>
#include <limits.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <string.h>
#include <sys/socket.h>

#include <QPair>
#include <QHostAddress>
#include <QSslConfiguration>

#include "Connection.h"
#include "macros.h"

/*
 * Server-side constructor: wraps an already-accepted socket descriptor.
 *
 * sock          - native descriptor of the accepted TCP connection.
 * configuration - TLS settings applied to the wrapped QSslSocket.
 *
 * Wires the socket's signals to this object's slots and enables TCP
 * keepalive with a one-second idle time so dead connections are noticed.
 */
Connection::Connection(int sock, QSslConfiguration &configuration) :
    state(BARE),
    communicating(false),
    type(SERVER),
    credentialsSize(0) {
    // Report (rather than silently ignore) a descriptor Qt refuses to adopt.
    if (!socket.setSocketDescriptor(sock)) {
        qWarning("Warning: Could not adopt socket descriptor %d: %s", sock, qPrintable(socket.errorString()));
    }
    socket.setSslConfiguration(configuration);

    connect(&socket, SIGNAL(connected()), SLOT(sockConnected()));
    connect(&socket, SIGNAL(disconnected()), SLOT(sockDisconnected()));
    connect(&socket, SIGNAL(error(QAbstractSocket::SocketError)), SLOT(sockError()));
    connect(&socket, SIGNAL(sslErrors(const QList<QSslError> &)), SLOT(sslErrors(const QList<QSslError> &)));
    connect(&socket, SIGNAL(encrypted()), SLOT(sockEncrypted()));
    connect(&socket, SIGNAL(readyRead()), SLOT(sockReadyRead()));
    connect(&socket, SIGNAL(bytesWritten(qint64)), SLOT(bytesWritten(qint64)));
    connect(this, SIGNAL(disconnectSignal()), SLOT(disconnect()));

    int yes = 1;
    if (setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, &yes, sizeof(yes))) {
        qWarning("Warning: Keepalive socket option set failed, dead connections will persist: %s", strerror(errno));
    }

    // TCP_KEEPIDLE (Linux/BSD) and TCP_KEEPALIVE (macOS) both set the idle
    // time before the first keepalive probe. IPPROTO_TCP is used instead of
    // the Linux-only SOL_TCP alias so this builds wherever TCP_KEEPIDLE exists.
#ifdef TCP_KEEPIDLE
    const int idleOption = TCP_KEEPIDLE;
#else
    const int idleOption = TCP_KEEPALIVE;
#endif
    int idleSeconds = 1;
    if (setsockopt(sock, IPPROTO_TCP, idleOption, &idleSeconds, sizeof(idleSeconds))) {
        qWarning("Warning: Keepalive idle time could not be set: %s", strerror(errno));
    }
}

/*
 * Client-side constructor: starts an outgoing connection to hostname:port.
 *
 * Wires the socket's signals to this object's slots, then calls
 * connectToHost(); connected() is reported through sockConnected().
 */
Connection::Connection(QString &hostname, quint16 port) :
    state(BARE),
    communicating(false),
    type(CLIENT),
    credentialsSize(0) {
    connect(&socket, SIGNAL(connected()), SLOT(sockConnected()));
    connect(&socket, SIGNAL(disconnected()), SLOT(sockDisconnected()));
    connect(&socket, SIGNAL(error(QAbstractSocket::SocketError)), SLOT(sockError()));
    connect(&socket, SIGNAL(sslErrors(const QList<QSslError> &)), SLOT(sslErrors(const QList<QSslError> &)));
    connect(&socket, SIGNAL(encrypted()), SLOT(sockEncrypted()));
    connect(&socket, SIGNAL(readyRead()), SLOT(sockReadyRead()));
    connect(&socket, SIGNAL(bytesWritten(qint64)), SLOT(bytesWritten(qint64)));
    connect(this, SIGNAL(disconnectSignal()), SLOT(disconnect()));

    // No TCP socket options are set here. Before connectToHost() has created
    // a native socket, socket.socketDescriptor() is -1, so any setsockopt()
    // on it fails (EBADF) and only prints a misleading warning. Keepalive has
    // to be configured once the connection is established (e.g. from the
    // connected() handler).

    socket.connectToHost(hostname, port);
}

void Connection::sockConnected() {
    emit connected();
}

void Connection::sockDisconnected() {
    qWarning("%s disconnected", qPrintable(socket.peerAddress().toString()));
    emit disconnected();
}

void Connection::startEncryption() {
    if (type == CLIENT) {
        socket.startClientEncryption();
    }
    else {
        socket.startServerEncryption();
    }
    qDebug("Starting encryption for %s", qPrintable(socket.peerAddress().toString()));
}

void Connection::sockEncrypted() {
    state = ENCRYPTED;

    qDebug("Connection from %s is encrypted", qPrintable(socket.peerAddress().toString()));

    emit encrypted();
}

/*
 * Enters the AUTHENTICATING state.
 *
 * On the client side this also sends the credentials record: one length
 * byte followed by "\0user\0pass\0", with user/pass taken from the
 * configuration. Records longer than 255 bytes cannot be length-prefixed
 * with one byte, so the connection is closed instead.
 */
void Connection::startAuthentication() {
    state = AUTHENTICATING;

    if (type == CLIENT) {
        QByteArray message;
        message.append('\0');
        message.append(CONFIG("user").toString());
        message.append('\0');
        message.append(CONFIG("pass").toString());
        message.append('\0');
        if (message.length() > 255) {
            // Don't leave the plaintext password behind on the error path.
            message.fill(0);
            qCritical("Error: Credentials are too long");
            disconnect();
            return;
        }
        message.prepend(static_cast<char>(message.length()));
        sockWrite(message);
        // socket.write() has copied the bytes; wipe our plaintext copy, as the
        // server side does with the credentials it receives.
        message.fill(0);
    }

    qDebug("Starting authentication with %s", qPrintable(socket.peerAddress().toString()));
}

void Connection::startCommunication() {
    communicating = TRUE;

    if (socket.bytesAvailable()) {
        sockReadyRead();
    }

    qDebug("Starting communication with %s", qPrintable(socket.peerAddress().toString()));
}

/*
 * Returns the remote address reported by the underlying socket.
 */
QHostAddress Connection::peerAddress() {
    QHostAddress address(socket.peerAddress());
    return address;
}

void Connection::sockReadyRead() {
    if (state == AUTHENTICATING) {
        if (type == CLIENT) {
            char message;
            qint64 ret = socket.read(&message, 1);
            if (ret < 0) {
                qCritical("Error: Failed to receive authentication message from server %s", qPrintable(socket.peerAddress().toString()));
                disconnect();
            }
            else if (ret == 1 && message == 'a') {
                state = AUTHENTICATED;
                qWarning("Authenticated to server %s", qPrintable(socket.peerAddress().toString()));
                emit authenticated();
            }
            else {
                qCritical("Error: Failed to authenticate to server %s", qPrintable(socket.peerAddress().toString()));
            }
        }
        else {
            if (!credentialsSize) {
                qint64 ret = socket.read((char *)&credentialsSize, 1);
                if (ret < 0 || ret > 1) {
                    qCritical("Failed to read size of credentials from client %s", qPrintable(socket.peerAddress().toString()));
                    disconnect();
                    return;
                }
                if (ret == 0) {
                    return;
                }
                if (credentialsSize < 5) {
                    qCritical("Size of credentials from client %s invalid", qPrintable(socket.peerAddress().toString()));
                    disconnect();
                    return;
                }
            }

            buffer += socket.read(credentialsSize - buffer.length());

            if (buffer.length() == credentialsSize) {
                qDebug("Read credentials from client %s", qPrintable(socket.peerAddress().toString()));

                if (buffer.at(0) != '\0' || buffer.count('\0') != 3) {
                    qCritical("Error: Credentials format incorrect");
                    disconnect();
                    return;
                }
                buffer.remove(0, 1);
                int index = buffer.indexOf('\0');
                if (index < 0) {
                    qCritical("Error: Credentials format incorrect");
                    buffer.clear();
                    disconnect();
                    return;
                }
                if (buffer.at(buffer.length() - 1) != '\0') {
                    qCritical("Error: Credentials format incorrect");
                    disconnect();
                    return;
                }
                buffer.remove(buffer.length() - 1, 1);
                QByteArray user(buffer.left(index));
                QByteArray pass(buffer.right(buffer.length() - (index + 1)));
                buffer.fill(0);

                emit checkPass(user, pass);
                user.fill(0);
                pass.fill(0);
            }
        }
    }
    else if (communicating) {
        emit readyRead();
    }
}

void Connection::sockError() {
    if (socket.error() == QAbstractSocket::RemoteHostClosedError) {
        return;
    }

    qCritical("Error from %s: %s", qPrintable(socket.peerAddress().toString()), qPrintable(socket.errorString()));
    if (socket.error() == QAbstractSocket::ConnectionRefusedError) {
        emit disconnected();
    }
    else {
        disconnect();
    }
}

void Connection::sslErrors(const QList<QSslError> &errors) {
    bool ignore = TRUE;
    for (int i = 0; i < errors.count(); i++) {
        switch(errors[i].error()) {
            case QSslError::HostNameMismatch:
            case QSslError::SelfSignedCertificate:
                qCritical("Security Layer Warning for %s: %s", qPrintable(socket.peerAddress().toString()), qPrintable(errors[i].errorString()));
                break;
            default:
                qCritical("An SSL error from %s occurred: %s", qPrintable(socket.peerAddress().toString()), qPrintable(errors[i].errorString()));
                ignore = FALSE;
                break;
        }
    }

    if (ignore) {
        socket.ignoreSslErrors();
    }
    else {
        qCritical("Exiting connection due to error(s)");
        disconnect();
    }
}

/*
 * Slot for the socket's bytesWritten() signal; debug trace only.
 */
void Connection::bytesWritten(qint64 size) {
    const QString peer = socket.peerAddress().toString();
    qDebug("%lld bytes written to %s", size, qPrintable(peer));
}

/*
 * Server side: receives the password check verdict and answers the client
 * with a single status byte, "a" when accepted or "i" when rejected.
 * A rejected client is disconnected; an accepted one moves to AUTHENTICATED.
 */
void Connection::checkPassResult(bool ok) {
    if (!ok) {
        sockWrite("i");
        qCritical("Error: Client %s failed to authenticate", qPrintable(socket.peerAddress().toString()));
        disconnect();
        return;
    }

    state = AUTHENTICATED;
    sockWrite("a");
    qWarning("Client %s authenticated", qPrintable(socket.peerAddress().toString()));
    emit authenticated();
}

bool Connection::sockWrite(const char *data, qint64 maxSize) {
    qint64 written = 0;

    qDebug("Writing %lld bytes: '%s'", maxSize, data);

    while (written < maxSize) {
        qint64 ret;
        ret = socket.write(data + written, maxSize - written);
        if (ret < 0) {
            qWarning("Failed to send complete data buffer");
            return false;
        }
        written += ret;
    }

    return true;
}

bool Connection::sockWrite(const char *data) {
    return sockWrite(data, qstrlen(data));
}

bool Connection::sockWrite(const QByteArray &byteArray) {
    return sockWrite(byteArray.data(), byteArray.length());
}

/*
 * Bytes readable by users of this Connection; reports 0 while the
 * authentication exchange is in progress.
 */
qint64 Connection::bytesAvailable() {
    return (state == AUTHENTICATING) ? 0 : socket.bytesAvailable();
}

/*
 * Peeks at up to maxSize bytes without consuming them; yields nothing (0)
 * while the authentication exchange is in progress.
 */
qint64 Connection::peek(char *data, qint64 maxSize) {
    if (state != AUTHENTICATING) {
        return socket.peek(data, maxSize);
    }
    return 0;
}

/*
 * Reads up to maxSize bytes into data; yields nothing (0) while the
 * authentication exchange is in progress.
 */
qint64 Connection::read(char *data, qint64 maxSize) {
    const bool locked = (state == AUTHENTICATING);
    return locked ? 0 : socket.read(data, maxSize);
}

/*
 * Reads up to maxSize bytes; returns an empty array while the
 * authentication exchange is in progress.
 */
QByteArray Connection::read(qint64 maxSize) {
    if (state == AUTHENTICATING) {
        // QByteArray(), not QByteArray::QByteArray(): the qualified form names
        // the constructor rather than the type and is rejected by conforming
        // compilers.
        return QByteArray();
    }
    return socket.read(maxSize);
}

/*
 * Reads everything currently buffered; returns an empty array while the
 * authentication exchange is in progress.
 */
QByteArray Connection::readAll() {
    if (state == AUTHENTICATING) {
        // QByteArray(), not QByteArray::QByteArray(): the qualified form names
        // the constructor rather than the type and is ill-formed.
        return QByteArray();
    }
    return socket.readAll();
}

/*
 * Sends application data. Refused (returns false) while the authentication
 * exchange is in progress; otherwise returns sockWrite()'s result.
 */
bool Connection::write(const QByteArray &message) {
    const bool blocked = (state == AUTHENTICATING);
    return !blocked && sockWrite(message);
}

void Connection::disconnect() {
    qWarning("Closing connection to %s", qPrintable(socket.peerAddress().toString()));
    socket.flush();
    socket.disconnectFromHost();
}

/*
 * Requests a disconnect by emitting disconnectSignal(), which both
 * constructors connect to disconnect().
 * NOTE(review): with the default AutoConnection the slot runs synchronously
 * when emitted from this object's own thread; confirm whether a queued
 * (truly deferred) disconnect was intended.
 */
void Connection::disconnectLater() {
    emit this->disconnectSignal();
}

/*
 * Closes the socket via disconnect() when the Connection is destroyed.
 */
Connection::~Connection() {
    Connection::disconnect();
}
