/*
 *  Copyright (c) 2008 Cyrille Berger <cberger@cberger.net>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * either version 2, or (at your option) any later version of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "Statement.h"

// LLVM
#include <llvm/ADT/ArrayRef.h>
#include <llvm/BasicBlock.h>
#include <llvm/Constants.h>
#include <llvm/Function.h>
#include <llvm/GlobalVariable.h>
#include <llvm/Instructions.h>
#include <llvm/Module.h>

// GTLCore
#include <GTLCore/LLVMBackend/CodeGenerator_p.h>
#include <GTLCore/LLVMBackend/ExpressionResult_p.h>
#include <GTLCore/Function.h>
#include <GTLCore/Type.h>
#include <GTLCore/Utils_p.h>
#include <GTLCore/VariableNG_p.h>
#include <GTLCore/LLVMBackend/Visitor_p.h>
#include <GTLCore/LLVMBackend/ExpressionGenerationContext_p.h>

// GTLCore
#include <GTLCore/Debug.h>

// AST
#include "Expression.h"
#include "../Function_p.h"

using namespace GTLCore::AST;

/**
 * Ask the generation context for a new basic block. Used by the control-flow
 * statements below to obtain the targets of their branches.
 *
 * @param context the generation context that creates the block
 * @return the newly created block
 */
llvm::BasicBlock* Statement::createBlock( LLVMBackend::GenerationContext& context) const
{
  llvm::BasicBlock* freshBlock = context.createBasicBlock();
  return freshBlock;
}

/**
 * Generate every statement of the list in order. The basic block returned by
 * one statement becomes the insertion block of the next one.
 *
 * Delayed statements queued in the context are flushed after each statement.
 * ReturnStatement is the exception: it calls flushDelayedStatement itself
 * before emitting its return instruction.
 *
 * @return the block that is current after the last statement
 */
llvm::BasicBlock* StatementsList::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb ) const
{
  llvm::BasicBlock* block = _bb;
  std::list<Statement*>::const_iterator cursor = m_list.begin();
  const std::list<Statement*>::const_iterator last = m_list.end();
  while( cursor != last )
  {
    Statement* statement = *cursor;
    block = statement->generateStatement( _context, block );
    const bool isReturn = ( dynamic_cast<ReturnStatement*>( statement ) != 0 );
    if( not isReturn )
    {
      block = _context.flushDelayedStatement( block );
    }
    ++cursor;
  }
  return block;
}

/**
 * Release the statements held by the list. The elements are handed to the
 * deleteAll helper.
 */
StatementsList::~StatementsList()
{
  deleteAll( m_list );
}

/**
 * A statement list counts as a return statement when its final statement
 * does. An empty list never does.
 */
bool StatementsList::isReturnStatement() const
{
  return not m_list.empty() and m_list.back()->isReturnStatement();
}

/**
 * Append @p _statement to the end of the list.
 *
 * The destructor hands m_list to deleteAll, so the list presumably takes
 * ownership of the statement. Confirm before sharing statements between lists.
 */
void StatementsList::appendStatement( Statement* _statement )
{
  m_list.insert( m_list.end(), _statement );
}

//---------------------------------------------------//
//--------------- VariableDeclaration ---------------//
//---------------------------------------------------//

/**
 * Declare a variable of @p type.
 *
 * @param type          declared type of the variable
 * @param initialiser   initial value expression, may be null; deleted by the destructor
 * @param constant      whether the variable is constant
 * @param _initialSizes array size expressions; only allowed for ARRAY types
 *                      (checked by the assertion). Deleted by the destructor.
 */
VariableDeclaration::VariableDeclaration( const GTLCore::Type* type, Expression* initialiser, bool constant, const std::list<Expression*>& _initialSizes)
  : m_variable( new GTLCore::VariableNG(type, constant, false) ),
    m_initialiser( initialiser ),
    m_initialSizes( _initialSizes ),
    m_functionInitialiser( 0 )
{
  // Either no initial sizes are given, or the declared type is an array.
  GTL_ASSERT( (not _initialSizes.empty() and type->dataType() == GTLCore::Type::ARRAY ) or _initialSizes.empty() );
}

/**
 * Release everything owned by the declaration:
 * - the variable created in the constructor
 * - the optional initialiser
 * - the initial size expressions (via deleteAll)
 * - the optional function initialiser
 *
 * Deleting a null pointer is a no-op, so absent optional parts are fine.
 */
VariableDeclaration::~VariableDeclaration()
{
  delete m_variable;
  delete m_initialiser;
  deleteAll( m_initialSizes );
  delete m_functionInitialiser;
}

/**
 * Emit the code that declares and initialises the variable.
 *
 * Generation proceeds in this order:
 * 1. The optional initialiser expression is evaluated.
 * 2. Each initial size expression is evaluated. A null entry is replaced by
 *    the constant 0.
 * 3. The variable is initialised in the block where expression generation
 *    ended.
 * 4. The optional function initialiser statement is generated after that.
 *
 * @return the block that is current once the declaration is complete
 */
llvm::BasicBlock* VariableDeclaration::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb ) const
{
  LLVMBackend::ExpressionResult initialValue;
  LLVMBackend::ExpressionGenerationContext exprContext( _bb );
  if( m_initialiser )
  {
    initialValue = m_initialiser->generateValue( _context, exprContext );
  }

  std::list<llvm::Value*> sizes;
  for( std::list<Expression*>::const_iterator sizeExpr = m_initialSizes.begin();
       sizeExpr != m_initialSizes.end(); ++sizeExpr )
  {
    if( not *sizeExpr )
    {
      // No size expression for this dimension: use the constant 0.
      sizes.push_back( _context.codeGenerator()->integerToConstant( _context.llvmContext(), INT32_C(0) ) );
      continue;
    }
    sizes.push_back( (*sizeExpr)->generateValue( _context, exprContext ).value() );
  }

  llvm::BasicBlock* afterInit = m_variable->initialise( _context, exprContext.currentBasicBlock(), initialValue, sizes );
  return m_functionInitialiser ? m_functionInitialiser->generateStatement( _context, afterInit ) : afterInit;
}

//---------------------------------------------------//
//------------------- IfStatement -------------------//
//---------------------------------------------------//
/**
 * Delete the condition expression and the body statement owned by this node.
 */
IfStatement::~IfStatement()
{
  delete m_expression;
  delete m_ifStatement;
}

/**
 * Emit an `if` without an `else` branch.
 *
 * The condition is evaluated starting in @p _bb. The body is generated from a
 * fresh block, and a join block is created afterwards. createIfStatement()
 * then wires the conditional branch from the block where condition
 * evaluation ended.
 *
 * @return the join block following the statement
 */
llvm::BasicBlock* IfStatement::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb ) const
{
  LLVMBackend::ExpressionGenerationContext conditionContext( _bb );
  llvm::Value* condition = m_expression->generateValue( _context, conditionContext ).value();
  // The join block is only created once the body has been generated.
  llvm::BasicBlock* bodyFirst = createBlock( _context );
  llvm::BasicBlock* bodyLast = m_ifStatement->generateStatement( _context, bodyFirst );
  llvm::BasicBlock* joinBlock = createBlock( _context );
  _context.codeGenerator()->createIfStatement( conditionContext.currentBasicBlock(), condition,
                                               m_expression->type(), bodyFirst, bodyLast, joinBlock );
  return joinBlock;
}

/**
 * Delete the condition expression and both branch statements owned by this
 * node.
 */
IfElseStatement::~IfElseStatement()
{
  delete m_expression;
  delete m_ifStatement;
  delete m_elseStatement;
}

/**
 * Emit an `if` / `else`.
 *
 * The condition is evaluated starting in @p _bb. Generation then proceeds in
 * this order:
 * 1. The "then" branch, from a fresh block.
 * 2. The "else" branch, from another fresh block.
 * 3. A join block.
 *
 * createIfElseStatement() wires everything together from the block where
 * condition evaluation ended.
 *
 * @return the join block following the statement
 */
llvm::BasicBlock* IfElseStatement::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  LLVMBackend::ExpressionGenerationContext conditionContext( _bb );
  llvm::Value* condition = m_expression->generateValue( _context, conditionContext ).value();

  llvm::BasicBlock* thenFirst = createBlock( _context );
  llvm::BasicBlock* thenLast = m_ifStatement->generateStatement( _context, thenFirst );

  llvm::BasicBlock* elseFirst = createBlock( _context );
  llvm::BasicBlock* elseLast = m_elseStatement->generateStatement( _context, elseFirst );

  llvm::BasicBlock* joinBlock = createBlock( _context );
  _context.codeGenerator()->createIfElseStatement( conditionContext.currentBasicBlock(), condition, m_expression->type(),
                                                   thenFirst, thenLast, elseFirst, elseLast, joinBlock );
  return joinBlock;
}

/**
 * Delete the parts of the loop owned by this node. The init statement and the
 * update expression are optional (generateStatement checks them for null);
 * deleting a null pointer is a no-op.
 */
ForStatement::~ForStatement()
{
  delete m_initStatement;
  delete m_testExpression;
  delete m_updateExpression;
  delete m_forStatement;
}

/**
 * Emit a `for` loop.
 *
 * The optional init statement is generated in @p _bb. The test expression,
 * the optional update expression and the loop body each start in their own
 * fresh block. createForStatement() then connects the init, test, update and
 * body blocks with the block that follows the loop.
 *
 * @return the block following the loop
 */
llvm::BasicBlock* ForStatement::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  // Generate the init block
  llvm::BasicBlock* initBlock = _bb;
  if( m_initStatement )
  {
    initBlock = m_initStatement->generateStatement( _context, _bb);
  }
  // Generate the test block
  // NOTE(review): m_testExpression is dereferenced unconditionally, so this
  // assumes the parser always supplies a test expression. Confirm.
  llvm::BasicBlock* testBlock = createBlock( _context );
  LLVMBackend::ExpressionGenerationContext egc(testBlock);
  llvm::Value* test = m_testExpression->generateValue( _context, egc ).value();
  // Generate the update block
  llvm::BasicBlock* updateBlock = createBlock( _context );
  if( m_updateExpression )
  {
    // NOTE(review): the block returned by generateStatement is discarded and
    // createForStatement() receives updateBlock. If an update expression can
    // end in a different block than it started in, the loop back-edge may be
    // attached to the wrong block. Confirm that this cannot happen.
    m_updateExpression->generateStatement( _context, updateBlock );
  }
  llvm::BasicBlock* startAction = createBlock( _context );
  llvm::BasicBlock* endAction = m_forStatement->generateStatement( _context, startAction );
  llvm::BasicBlock* after = createBlock( _context );
  // egc.currentBasicBlock() is the block in which the test evaluation ended,
  // which may differ from testBlock.
  _context.codeGenerator()->createForStatement( initBlock, testBlock, egc.currentBasicBlock(),  test,  m_testExpression->type(), updateBlock, startAction, endAction, after);
  return after;
}

/**
 * Delete the loop condition expression and the loop body owned by this node.
 */
WhileStatement::~WhileStatement()
{
  delete m_expression;
  delete m_whileStatement;
}

/**
 * Emit a `while` loop.
 *
 * The condition is evaluated in its own fresh block and must not spill into
 * additional blocks (asserted). The body starts in another fresh block, and
 * the exit block is created last. createWhileStatement() links @p _bb, the
 * condition, the body and the exit block together.
 *
 * @return the block following the loop
 */
llvm::BasicBlock* WhileStatement::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  llvm::BasicBlock* testBlock = createBlock( _context );
  LLVMBackend::ExpressionGenerationContext egc(testBlock);
  llvm::Value* condition = m_expression->generateValue( _context, egc ).value();
  GTL_ASSERT( egc.currentBasicBlock() == testBlock);

  llvm::BasicBlock* bodyFirst = createBlock( _context );
  llvm::BasicBlock* bodyLast = m_whileStatement->generateStatement( _context, bodyFirst );
  llvm::BasicBlock* exitBlock = createBlock( _context );
  _context.codeGenerator()->createWhileStatement( _bb, testBlock, condition, m_expression->type(), bodyFirst, bodyLast, exitBlock );
  return exitBlock;
}

//------------------------- ReturnStatement -------------------------//

/**
 * @param _returnExpr                 the expression whose value is returned, or
 *                                    null for a `return` without a value. A
 *                                    non-null expression is flagged with
 *                                    markAsReturnExpression().
 * @param _garbageCollectionStatement statement generated before the return
 *                                    instruction; must not be null.
 */
ReturnStatement::ReturnStatement( Expression* _returnExpr, Statement* _garbageCollectionStatement )
  : m_returnExpr( _returnExpr ),
    m_garbageCollectionStatement( _garbageCollectionStatement )
{
  GTL_ASSERT( m_garbageCollectionStatement );
  if( not m_returnExpr ) return;
  m_returnExpr->markAsReturnExpression();
}

/**
 * Delete the optional return expression and the garbage-collection statement.
 */
ReturnStatement::~ReturnStatement()
{
  delete m_returnExpr;
  delete m_garbageCollectionStatement;
}

/**
 * Emit a `return`.
 *
 * With a return expression, generation proceeds in this order:
 * 1. The expression is evaluated.
 * 2. The result is passed to the type's Visitor::mark with 1.
 * 3. Delayed statements are flushed.
 * 4. The garbage-collection statement is generated.
 * 5. For non-array, non-structure types, the value is converted to the
 *    function's return type.
 * 6. The value is passed to Visitor::mark with -1.
 * 7. The value is either stored through the return pointer followed by a void
 *    `ret` (when the function returns it as a pointer), or returned directly.
 *
 * Without an expression, only the flush, the garbage-collection statement and
 * a void `ret` are emitted.
 *
 * NOTE(review): the +1 / -1 mark pair looks like it keeps the result alive
 * across the cleanup code. Confirm against Visitor::mark.
 *
 * @return the block that now ends with the return instruction
 */
llvm::BasicBlock* ReturnStatement::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  // FIXME share the cleanup code
  if( m_returnExpr )
  {
    LLVMBackend::ExpressionGenerationContext egc(_bb);
    LLVMBackend::ExpressionResult result = m_returnExpr->generateValue( _context, egc);
    const GTLCore::Type* returnType = m_returnExpr->type();
    const LLVMBackend::Visitor* visitor = LLVMBackend::Visitor::getVisitorFor( returnType );
    // Mark the result with +1 before the delayed statements and the
    // garbage-collection statement are generated.
    _bb = visitor->mark( _context, egc.currentBasicBlock(), result.value(), returnType, LLVMBackend::CodeGenerator::integerToConstant( _context.llvmContext(), INT32_C(1) ) );
    _bb = _context.flushDelayedStatement( _bb );
    _bb = m_garbageCollectionStatement->generateStatement( _context, _bb );
    llvm::Value* resultValue = result.value();
    // Arrays and structures are returned without conversion.
    if( m_returnExpr->type()->dataType() != Type::ARRAY
        and m_returnExpr->type()->dataType() != Type::STRUCTURE )
    {
      resultValue = _context.codeGenerator()->convertValueTo(_bb, resultValue, m_returnExpr->type(), _context.function()->returnType() );
    }
    // Mark the return value
    // NOTE(review): resultValue may already be converted to the function's
    // return type while the expression type is still passed here. This is
    // presumably harmless because the conversion only applies to
    // non-array/non-structure types. Verify.
    _bb = visitor->mark( _context, _bb, resultValue, returnType, LLVMBackend::CodeGenerator::integerToConstant(  _context.llvmContext(), INT32_C(-1) ) );
    if(_context.function()->d->isReturnedAsPointer())
    {
      // Write the value through the caller-provided return pointer and emit a
      // void return.
      new llvm::StoreInst(resultValue, _context.returnPointer(), _bb);
      llvm::ReturnInst::Create( _context.llvmContext(), _bb );
    } else {
      llvm::ReturnInst::Create( _context.llvmContext(), resultValue, _bb);
    }
  } else {
    // `return` without a value: only cleanup, then a void return.
    _bb = _context.flushDelayedStatement( _bb );
    _bb = m_garbageCollectionStatement->generateStatement( _context, _bb );
    llvm::ReturnInst::Create( _context.llvmContext(), _bb );
  }
  return _bb;
}

//------------------------- PrintStatement --------------------------//

/**
 * Release the expressions to print. They are handed to the deleteAll helper.
 */
PrintStatement::~PrintStatement()
{
  deleteAll( m_expressions );
}

/**
 * Emit a call to the variadic runtime function `print`, declared here as
 * `void print(i32, ...)`.
 *
 * The call's arguments are a leading i32 count of printed expressions,
 * followed by an i32 type tag and a value for each printed expression:
 *   - 0: 32-bit integer, passed unchanged
 *   - 1: float, converted to Float64
 *   - 2: boolean (i1), converted to Integer32
 *   - 3: string (expression without a GTL type), passed as a pointer to an
 *        internal constant global variable
 * Expressions of any other LLVM type are reported with GTL_DEBUG and skipped.
 *
 * @return the block that is current after evaluating all expressions
 */
llvm::BasicBlock* PrintStatement::generateStatement( LLVMBackend::GenerationContext& _context, llvm::BasicBlock* _bb1 ) const
{
  std::vector<llvm::Type*> params;
  params.push_back( llvm::Type::getInt32Ty(_context.llvmContext()));
  llvm::FunctionType* definitionType = llvm::FunctionType::get( llvm::Type::getVoidTy(_context.llvmContext()), params, true );
  // getOrInsertFunction() returns a bitcast of the existing declaration when
  // "print" is already declared with another prototype. Keep the result as a
  // Constant and call through it instead of casting to llvm::Function.
  // LLVM is normally built without C++ RTTI, and the old dynamic_cast
  // returned null in the bitcast case.
  llvm::Constant* func = _context.llvmModule()->getOrInsertFunction("print", definitionType);

  std::vector<llvm::Value*> values;
  // Placeholder for the leading count. It is filled in after the loop, once
  // the number of printable expressions is known.
  values.push_back( static_cast<llvm::Value*>( 0 ) );

  LLVMBackend::ExpressionGenerationContext egc(_bb1);
  for( std::list<AST::Expression*>::const_iterator it = m_expressions.begin();
       it != m_expressions.end(); ++it)
  {
    LLVMBackend::ExpressionResult value = (*it)->generateValue( _context, egc);
    const llvm::Type* type = value.value()->getType();
    if( (*it)->type() == 0 )
    { // It's a string
      values.push_back( _context.codeGenerator()->integerToConstant( _context.llvmContext(), INT32_C(3)) );
      values.push_back( new llvm::GlobalVariable( *_context.llvmModule(), value.value()->getType(), true, llvm::GlobalValue::InternalLinkage, value.constant(), "" ) );
    } else if( type == llvm::Type::getInt32Ty(_context.llvmContext()) )
    {
      values.push_back( _context.codeGenerator()->integerToConstant( _context.llvmContext(), INT32_C(0)) );
      values.push_back( value.value() );
    } else if( type == llvm::Type::getFloatTy(_context.llvmContext()) )
    {
      values.push_back( _context.codeGenerator()->integerToConstant( _context.llvmContext(), INT32_C(1)) );
      values.push_back( _context.codeGenerator()->convertValueTo( egc.currentBasicBlock(), value.value(), (*it)->type(), GTLCore::Type::Float64 ));
    } else if( type == llvm::Type::getInt1Ty(_context.llvmContext() ) )
    {
      values.push_back( _context.codeGenerator()->integerToConstant( _context.llvmContext(), INT32_C(2)) );
      values.push_back( _context.codeGenerator()->convertValueTo( egc.currentBasicBlock(), value.value(), (*it)->type(), GTLCore::Type::Integer32 ));
    } else {
      GTL_DEBUG("Unknown type for print " << *type);
    }
  }
  // Each printed expression contributed exactly two values (tag + value).
  // Announcing m_expressions.size() instead would make the callee read past
  // the end of its varargs whenever an expression of unknown type was skipped.
  values[0] = _context.codeGenerator()->integerToConstant( _context.llvmContext(), gtl_uint32( ( values.size() - 1 ) / 2 ) );
  llvm::CallInst::Create(func, values, "", egc.currentBasicBlock());
  return egc.currentBasicBlock();
}
