/* Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. 
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * @file
 *
 * @Author christian liesch <liesch@gmx.ch>
 *
 * SSL helper routines for the HTTP Test Tool: OpenSSL thread-locking
 * callbacks, PRNG seeding and SSL handshake/accept wrappers.
 */

/************************************************************************
 * Includes
 ***********************************************************************/
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/err.h>

#include <apr.h>
#include <apr_strings.h>
#include <apr_file_io.h>
#include <apr_portable.h>
#include <apr_errno.h>

#if APR_HAVE_UNISTD_H
#include <unistd.h> /* for getpid() */
#endif

#ifndef RAND_MAX
#include <limits.h>
#define RAND_MAX INT_MAX
#endif

#include "defines.h"
#include "ssl.h"

#ifdef USE_SSL

/************************************************************************
 * Definitions 
 ***********************************************************************/


/************************************************************************
 * Forward declaration 
 ***********************************************************************/

/* OpenSSL thread-id callback, registered in ssl_util_thread_setup() */
static unsigned long ssl_util_thr_id(void);
/* OpenSSL locking callback, registered in ssl_util_thread_setup() */
static void ssl_util_thr_lock(int mode, int type, const char *file, int line); 
/* pseudo random int, clamped to the range [l, h] */
static int ssl_rand_choosenum(int l, int h); 
/* pool cleanup that unregisters the OpenSSL callbacks again */
static apr_status_t ssl_util_thread_cleanup(void *data); 

/************************************************************************
 * Implementation
 ***********************************************************************/

/**
 * Mutexes handed to OpenSSL through the locking callback so that OpenSSL
 * can be used from several threads.  lock_cs holds lock_num_locks entries,
 * one per lock id reported by CRYPTO_num_locks(); both are filled in by
 * ssl_util_thread_setup().
 */
static apr_thread_mutex_t **lock_cs;
static int lock_num_locks;

/**
 * Thread setup (SSL call back)
 *
 * Allocates one APR mutex per OpenSSL lock id from @p p, installs the
 * OpenSSL thread-id and locking callbacks, and registers a pool cleanup
 * that unregisters those callbacks again.
 *
 * If a mutex cannot be created the process exits: the slot in lock_cs
 * would otherwise stay uninitialized (apr_palloc does not zero memory)
 * and the locking callback would dereference it later.
 *
 * @param p IN pool the mutexes and the cleanup are attached to
 */
void ssl_util_thread_setup(apr_pool_t * p) {
  int i;
  apr_status_t status;

  lock_num_locks = CRYPTO_num_locks();
  lock_cs = apr_palloc(p, lock_num_locks * sizeof(*lock_cs));

  for (i = 0; i < lock_num_locks; i++) {
    status = apr_thread_mutex_create(&lock_cs[i], APR_THREAD_MUTEX_DEFAULT, p);
    if (status != APR_SUCCESS) {
      /* same fatal handling as ssl_util_thr_lock() */
      fprintf(stderr, "Fatal error could not create ssl lock %d\n", i);
      exit(status);
    }
  }

  CRYPTO_set_id_callback(ssl_util_thr_id);

  CRYPTO_set_locking_callback(ssl_util_thr_lock);

  apr_pool_cleanup_register(p, NULL, ssl_util_thread_cleanup,
                            apr_pool_cleanup_null);
}

/**
 * Do a seed
 *
 * Feeds the OpenSSL PRNG (RAND_seed) with the current time, the current
 * process id and a 128 byte window of the run-time stack.
 */
void ssl_rand_seed(void) {
  int n;
  time_t t;
  pid_t pid;
  unsigned char stackdata[256];

  /*
   * seed in the current time (sizeof(time_t) bytes)
   */
  t = time(NULL);
  RAND_seed((unsigned char *) &t, sizeof(t));

  /*
   * seed in the current process id (sizeof(pid_t) bytes)
   */
  pid = getpid();
  RAND_seed((unsigned char *) &pid, sizeof(pid));

  /*
   * seed in some current state of the run-time stack (128 bytes).
   * stackdata is deliberately left uninitialized; memory checkers such as
   * Valgrind or MSan will report this read.
   */
  n = ssl_rand_choosenum(0, sizeof(stackdata) - 128 - 1);
  RAND_seed(stackdata + n, 128);
}

/**
 * ssl handshake client side
 *
 * Drives SSL_do_handshake() until it completes.  The handshake is retried
 * while OpenSSL reports SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
 * Every other error code ends the handshake as failed, so an unexpected
 * code (e.g. SSL_ERROR_ZERO_RETURN) can no longer spin forever.
 *
 * @param ssl IN ssl object
 * @param error OUT error text, NULL on success
 * @param pool IN pool the error text is allocated from
 *
 * @return APR_EINVAL if no ssl context or
 *         APR_ECONNREFUSED if could not handshake or
 *         APR_SUCCESS
 */
apr_status_t ssl_handshake(SSL *ssl, char **error, apr_pool_t *pool) {
  *error = NULL;

  /* check first if we have a ssl context */
  if (!ssl) {
    *error = apr_pstrdup(pool, "No ssl context");
    return APR_EINVAL;
  }

  for (;;) {
    int ret, ecode;

    /* short pause between attempts (apr_sleep takes microseconds) */
    apr_sleep(1);

    ret = SSL_do_handshake(ssl);
    ecode = SSL_get_error(ssl, ret);

    switch (ecode) {
    case SSL_ERROR_NONE:
      return APR_SUCCESS;
    case SSL_ERROR_WANT_READ:
    case SSL_ERROR_WANT_WRITE:
      /* underlying I/O not ready yet: try again */
      break;
    default:
      /* SSL_ERROR_WANT_CONNECT, SSL_ERROR_SSL, SSL_ERROR_SYSCALL and any
       * other code are treated as a failed handshake */
      *error = apr_pstrdup(pool, "Handshake failed");
      return APR_ECONNREFUSED;
    }
  }
}

/**
 * ssl accept
 *
 * Runs SSL_accept() until the handshake is finished.  The handshake is
 * retried while OpenSSL reports SSL_ERROR_WANT_READ or
 * SSL_ERROR_WANT_WRITE.  *error is only set when a failure is returned.
 *
 * @param ssl IN ssl object
 * @param error OUT error text, NULL on success
 * @param pool IN pool the error text is allocated from
 *
 * @return APR_SUCCESS if the handshake completed,
 *         APR_EINVAL if no ssl context,
 *         APR_ENOTSOCK if the client speaks plain HTTP,
 *         APR_ECONNABORTED on any other failure
 */
apr_status_t ssl_accept(SSL *ssl, char **error, apr_pool_t *pool) {
  *error = NULL;

  /* check first if we have a ssl context */
  if (!ssl) {
    *error = apr_pstrdup(pool, "No ssl context");
    return APR_EINVAL;
  }

  for (;;) {
    int rc;
    int err;

    /* short pause between attempts (apr_sleep takes microseconds) */
    apr_sleep(1);
    if (SSL_is_init_finished(ssl)) {
      return APR_SUCCESS;
    }

    if ((rc = SSL_accept(ssl)) > 0) {
      return APR_SUCCESS;
    }

    err = SSL_get_error(ssl, rc);

    if (err == SSL_ERROR_ZERO_RETURN) {
      *error = apr_pstrdup(pool, "SSL accept connection closed");
      return APR_ECONNABORTED;
    }
    if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
      /* handshake needs more I/O: try again without reporting an error */
      continue;
    }
    if (ERR_GET_LIB(ERR_peek_error()) == ERR_LIB_SSL &&
        ERR_GET_REASON(ERR_peek_error()) == SSL_R_HTTP_REQUEST) {
      /*
       * OpenSSL has recognized a HTTP request: the client speaks plain
       * HTTP on our HTTPS port.  Reported as APR_ENOTSOCK so it can be
       * told apart from other handshake failures.
       */
      *error = apr_pstrdup(pool, "SSL accept client speaks plain HTTP");
      return APR_ENOTSOCK;
    }
    if (err == SSL_ERROR_SYSCALL) {
      *error = apr_pstrdup(pool, 
                 "SSL accept interrupted by system "
                 "[Hint: Stop button pressed in browser?!]");
      return APR_ECONNABORTED;
    }

    /* SSL_ERROR_SSL and any unexpected condition */
    *error = apr_psprintf(pool, "SSL library error %d in accept", err);
    return APR_ECONNABORTED;
  }
}

/**
 * This is a SSL lock call back
 *
 * Locks or unlocks the mutex lock_cs[type], depending on whether
 * CRYPTO_LOCK is set in @p mode.  Lock ids outside [0, lock_num_locks)
 * are ignored.  A failing lock/unlock is fatal and exits the process.
 *
 * @param mode IN lock mode
 * @param type IN lock type (index into lock_cs)
 * @param file IN unused
 * @param line IN unused
 */
static void ssl_util_thr_lock(int mode, int type, const char *file, int line) {
  apr_status_t status;

  (void) file;
  (void) line;

  /* never index outside the array allocated in ssl_util_thread_setup() */
  if (type < 0 || type >= lock_num_locks) {
    return;
  }

  if (mode & CRYPTO_LOCK) {
    if ((status = apr_thread_mutex_lock(lock_cs[type])) != APR_SUCCESS) {
      fprintf(stderr, "Fatal error could not lock\n");
      exit(status);
    }
  }
  else {
    if ((status = apr_thread_mutex_unlock(lock_cs[type])) != APR_SUCCESS) {
      fprintf(stderr, "Fatal error could not unlock\n");
      exit(status);
    }
  }
}

/**
 * OpenSSL thread-id callback (installed via CRYPTO_set_id_callback in
 * ssl_util_thread_setup).
 *
 * @return current thread id (SSL call back) as an unsigned long
 */
static unsigned long ssl_util_thr_id(void) {
  /* OpenSSL needs this to return an unsigned long.  On OS/390, the pthread
   * id is a structure twice that big.  Use the TCB pointer instead as a
   * unique unsigned long.
   */
#ifdef __MVS__
  /* Overlay a struct on memory starting at address 0 and read the field
   * that follows 540 unmapped bytes (PSATOLD).  The pointer to address 0
   * is intentional on this platform; do not "fix" it. */
  struct PSA {
    char unmapped[540];
    unsigned long PSATOLD;
  }  *psaptr = 0;

  return psaptr->PSATOLD;
#else
  /* NOTE(review): assumes the native thread handle returned by
   * apr_os_thread_current() can be cast to unsigned long uniquely. */
  return (unsigned long) apr_os_thread_current();
#endif
}

/**
 * Pool cleanup registered by ssl_util_thread_setup().
 *
 * Unhooks the OpenSSL locking and thread-id callbacks.  The mutexes in
 * lock_cs are not destroyed here; their own registered cleanups do that.
 *
 * @param data IN unused (registered with NULL)
 *
 * @return APR_SUCCESS
 */
static apr_status_t ssl_util_thread_cleanup(void *data) {
  (void) data;

  CRYPTO_set_locking_callback(NULL);
  CRYPTO_set_id_callback(NULL);

  return APR_SUCCESS;
}

/**
 * Rand between low and high
 *
 * Picks a pseudo random value in [l, h].  Previously the value was
 * computed from 0 instead of from l, so for l > 0 the result collapsed
 * to l or l + 1.  If h <= l, h is returned, exactly as the old clamping
 * did.
 *
 * NOTE(review): srand() is still called on every invocation for backward
 * compatibility.  This resets the global rand() stream, and calls made
 * within the same second yield the same value.
 *
 * @param l IN bottom
 * @param h IN top value
 *
 * @return something between l and h (inclusive)
 */
static int ssl_rand_choosenum(int l, int h) {
  long long span;
  long long pick;
  double frac;

  srand((unsigned int) time(NULL));
  if (h <= l) {
    return h;
  }

  frac = (double) rand() / RAND_MAX;           /* in [0, 1] */
  span = (long long) h - l;                    /* > 0, cannot overflow */
  pick = l + (long long) (frac * (double) span + 0.5);  /* round */
  return pick > h ? h : (int) pick;
}

#endif
