//===- CIRTypes.cpp - MLIR CIR Types --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types in the CIR dialect.
//
//===----------------------------------------------------------------------===//

#include "clang/CIR/Dialect/IR/CIRTypes.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "clang/Basic/AddressSpaces.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypesDetails.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/TypeSwitch.h"

//===----------------------------------------------------------------------===//
// CIR Helpers
//===----------------------------------------------------------------------===//
bool cir::isSized(mlir::Type ty) {
  // Only types implementing SizedTypeInterface can report a size; anything
  // else is treated as unsized.
  auto sizedTy = mlir::dyn_cast<cir::SizedTypeInterface>(ty);
  if (!sizedTy) {
    assert(!cir::MissingFeatures::unsizedTypes());
    return false;
  }
  return sizedTy.isSized();
}

//===----------------------------------------------------------------------===//
// CIR Custom Parser/Printer Signatures
//===----------------------------------------------------------------------===//

// Custom parser/printer for the parameter list of `!cir.func`, written as
// `(<type>, ..., ...)`. Both are defined later in this file.
// NOTE(review): presumably declared here because the TableGen'd type code
// included below references them.
static mlir::ParseResult
parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
                    bool &isVarArg);

static void printFuncTypeParams(mlir::AsmPrinter &p,
                                mlir::ArrayRef<mlir::Type> params,
                                bool isVarArg);

//===----------------------------------------------------------------------===//
// AddressSpace
//===----------------------------------------------------------------------===//

// Custom parser/printer for the `addrspace` parameter of `!cir.ptr`, written
// as `target_address_space(N)`. Both are defined further down in this file.
// NOTE(review): presumably declared before the generated type definitions are
// included because that generated code calls them; confirm against the .td.
mlir::ParseResult parseTargetAddressSpace(mlir::AsmParser &p,
                                          cir::TargetAddressSpaceAttr &attr);

void printTargetAddressSpace(mlir::AsmPrinter &p,
                             cir::TargetAddressSpaceAttr attr);

//===----------------------------------------------------------------------===//
// Get autogenerated stuff
//===----------------------------------------------------------------------===//

namespace cir {

#include "clang/CIR/Dialect/IR/CIRTypeConstraints.cpp.inc"

} // namespace cir

#define GET_TYPEDEF_CLASSES
#include "clang/CIR/Dialect/IR/CIROpsTypes.cpp.inc"

using namespace mlir;
using namespace cir;

//===----------------------------------------------------------------------===//
// General CIR parsing / printing
//===----------------------------------------------------------------------===//

Type CIRDialect::parseType(DialectAsmParser &parser) const {
  const llvm::SMLoc startLoc = parser.getCurrentLocation();
  llvm::StringRef keyword;
  Type parsed;

  // TableGen'd types are tried first. The generated parser also reports the
  // mnemonic it read, which is used below for the hand-written types.
  if (generatedTypeParser(parser, &keyword, parsed).has_value())
    return parsed;

  // Hand-written (raw C++) types.
  if (keyword == "record")
    return RecordType::parse(parser);

  parser.emitError(startLoc) << "unknown CIR type: " << keyword;
  return Type();
}

void CIRDialect::printType(Type type, DialectAsmPrinter &os) const {
  // Only TableGen'd printers exist today; reaching the error means a type has
  // no printer at all.
  // TODO(CIR) Attempt to print as a raw C++ type.
  if (failed(generatedTypePrinter(type, os)))
    llvm::report_fatal_error("printer is missing a handler for this type");
}

//===----------------------------------------------------------------------===//
// RecordType Definitions
//===----------------------------------------------------------------------===//

/// Parse the body of a `!cir.record<...>` type.
///
/// Accepted forms are:
///   <kind "name">                                   reference to a record
///                                                   currently being parsed
///   <kind ["name"] [packed] [padded] {members...}>  complete record
///   <kind "name" [packed] [padded] incomplete>      named incomplete record
/// where kind is `struct`, `union` or `class`. Anonymous records must be
/// complete. Returns a null type after emitting a diagnostic on failure.
Type RecordType::parse(mlir::AsmParser &parser) {
  // Keeps a named record registered as "being parsed" until this function
  // returns, so nested `<kind "name">` references resolve as self-references.
  FailureOr<AsmParser::CyclicParseReset> cyclicParseGuard;
  const llvm::SMLoc loc = parser.getCurrentLocation();
  const mlir::Location eLoc = parser.getEncodedSourceLoc(loc);
  bool packed = false;
  bool padded = false;
  RecordKind kind;
  mlir::MLIRContext *context = parser.getContext();

  if (parser.parseLess())
    return {};

  // TODO(cir): in the future we should probably separate types for different
  // source language declarations such as cir.record and cir.union
  if (parser.parseOptionalKeyword("struct").succeeded())
    kind = RecordKind::Struct;
  else if (parser.parseOptionalKeyword("union").succeeded())
    kind = RecordKind::Union;
  else if (parser.parseOptionalKeyword("class").succeeded())
    kind = RecordKind::Class;
  else {
    parser.emitError(loc, "unknown record type");
    return {};
  }

  // Optional quoted name. If a string is present but fails to parse, the
  // parser has already emitted a diagnostic; stop instead of continuing with
  // a half-parsed type.
  mlir::StringAttr name;
  OptionalParseResult nameResult = parser.parseOptionalAttribute(name);
  if (nameResult.has_value() && failed(*nameResult))
    return {};

  // Is a self reference: ensure referenced type was parsed.
  if (name && parser.parseOptionalGreater().succeeded()) {
    RecordType type = getChecked(eLoc, context, name, kind);
    // getChecked returns a null type (after diagnosing) when verification
    // fails, e.g. for an empty name.
    if (!type)
      return {};
    if (succeeded(parser.tryStartCyclicParse(type))) {
      parser.emitError(loc, "invalid self-reference within record");
      return {};
    }
    return type;
  }

  // Is a named record definition: ensure name has not been parsed yet.
  if (name) {
    RecordType type = getChecked(eLoc, context, name, kind);
    if (!type)
      return {};
    cyclicParseGuard = parser.tryStartCyclicParse(type);
    if (failed(cyclicParseGuard)) {
      parser.emitError(loc, "record already defined");
      return {};
    }
  }

  if (parser.parseOptionalKeyword("packed").succeeded())
    packed = true;

  if (parser.parseOptionalKeyword("padded").succeeded())
    padded = true;

  // Parse record members or lack thereof.
  bool incomplete = true;
  llvm::SmallVector<mlir::Type> members;
  if (parser.parseOptionalKeyword("incomplete").failed()) {
    incomplete = false;
    const auto delimiter = AsmParser::Delimiter::Braces;
    const auto parseElementFn = [&parser, &members]() {
      return parser.parseType(members.emplace_back());
    };
    if (parser.parseCommaSeparatedList(delimiter, parseElementFn).failed())
      return {};
  }

  if (parser.parseGreater())
    return {};

  // Try to create the proper record type.
  ArrayRef<mlir::Type> membersRef(members); // Needed for template deduction.
  mlir::Type type = {};
  if (name && incomplete) { // Identified & incomplete
    type = getChecked(eLoc, context, name, kind);
  } else if (!name && !incomplete) { // Anonymous & complete
    type = getChecked(eLoc, context, membersRef, packed, padded, kind);
  } else if (!incomplete) { // Identified & complete
    type = getChecked(eLoc, context, membersRef, name, packed, padded, kind);
    // A failed verification yields a null type; do not cast it below.
    if (!type)
      return {};
    // If the record has a self-reference, its type already exists in a
    // incomplete state. In this case, we must complete it.
    if (mlir::cast<RecordType>(type).isIncomplete())
      mlir::cast<RecordType>(type).complete(membersRef, packed, padded);
    assert(!cir::MissingFeatures::astRecordDeclAttr());
  } else { // anonymous & incomplete
    parser.emitError(loc, "anonymous records must be complete");
    return {};
  }

  return type;
}

void RecordType::print(mlir::AsmPrinter &printer) const {
  printer << '<';

  // Record flavour keyword, always followed by one space.
  switch (getKind()) {
  case RecordKind::Struct:
    printer << "struct ";
    break;
  case RecordKind::Union:
    printer << "union ";
    break;
  case RecordKind::Class:
    printer << "class ";
    break;
  }

  if (mlir::StringAttr recordName = getName())
    printer << recordName;

  // If this record is already being printed further up the stack, emit only
  // the short self-reference form. The guard stays alive until the end of
  // this function, covering the member list printed below.
  FailureOr<AsmPrinter::CyclicPrintReset> printGuard =
      printer.tryStartCyclicPrint(*this);
  if (failed(printGuard)) {
    printer << '>';
    return;
  }

  printer << ' ';
  if (getPacked())
    printer << "packed ";
  if (getPadded())
    printer << "padded ";

  if (!isIncomplete()) {
    printer << '{';
    llvm::interleaveComma(getMembers(), printer);
    printer << '}';
  } else {
    printer << "incomplete";
  }

  printer << '>';
}

mlir::LogicalResult
RecordType::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
                   llvm::ArrayRef<mlir::Type> members, mlir::StringAttr name,
                   bool incomplete, bool packed, bool padded,
                   RecordType::RecordKind kind) {
  // The only constraint checked here: a record that has a name must not
  // have an empty one. All other parameters are accepted as given.
  const bool hasEmptyName = name && name.getValue().empty();
  if (!hasEmptyName)
    return mlir::success();
  return emitError() << "identified records cannot have an empty name";
}

// Accessors forwarding to the record's type storage.

::llvm::ArrayRef<mlir::Type> RecordType::getMembers() const {
  const auto *storage = getImpl();
  return storage->members;
}

bool RecordType::isIncomplete() const { return getImpl()->incomplete; }

mlir::StringAttr RecordType::getName() const { return getImpl()->name; }

// Same value as isIncomplete().
bool RecordType::getIncomplete() const { return isIncomplete(); }

bool RecordType::getPacked() const { return getImpl()->packed; }

bool RecordType::getPadded() const { return getImpl()->padded; }

cir::RecordType::RecordKind RecordType::getKind() const {
  const auto *storage = getImpl();
  return storage->kind;
}

/// Complete a previously incomplete record in place with the given members
/// and packing/padding flags. Failure to mutate is treated as unreachable.
void RecordType::complete(ArrayRef<Type> members, bool packed, bool padded) {
  assert(!cir::MissingFeatures::astRecordDeclAttr());
  const mlir::LogicalResult mutated = mutate(members, packed, padded);
  if (failed(mutated))
    llvm_unreachable("failed to complete record");
}

/// Return the member of this union that determines its layout: the member
/// with the strictest ABI alignment, ties broken by the larger size.
///
/// Members are compared as-is; nested unions are not recursed into. When the
/// union is padded, its last member is the padding and is ignored, unless it
/// is the only member.
Type RecordType::getLargestMember(const ::mlir::DataLayout &dataLayout) const {
  assert(isUnion() && "Only call getLargestMember on unions");
  llvm::ArrayRef<Type> members = getMembers();
  assert(!members.empty() && "union has no members");

  // Drop the trailing padding member, but keep it if nothing else is left
  // (the padding is then the only storage the union has).
  if (getPadded() && members.size() > 1)
    members = members.drop_back();

  return *std::max_element(
      members.begin(), members.end(), [&](Type lhs, Type rhs) {
        const uint64_t lhsAlign = dataLayout.getTypeABIAlignment(lhs);
        const uint64_t rhsAlign = dataLayout.getTypeABIAlignment(rhs);
        if (lhsAlign != rhsAlign)
          return lhsAlign < rhsAlign;
        return dataLayout.getTypeSize(lhs) < dataLayout.getTypeSize(rhs);
      });
}

bool RecordType::isLayoutIdentical(const RecordType &other) {
  if (getImpl() == other.getImpl())
    return true;

  if (getPacked() != other.getPacked())
    return false;

  return getMembers() == other.getMembers();
}

//===----------------------------------------------------------------------===//
// Data Layout information for types
//===----------------------------------------------------------------------===//

llvm::TypeSize
PointerType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
                               ::mlir::DataLayoutEntryListRef params) const {
  // Every pointer is currently modelled as 64 bits wide, whatever its
  // address space.
  // FIXME: improve this in face of address spaces
  assert(!cir::MissingFeatures::dataLayoutPtrHandlingBasedOnLangAS());
  constexpr uint64_t pointerSizeInBits = 64;
  return llvm::TypeSize::getFixed(pointerSizeInBits);
}

uint64_t
PointerType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
                             ::mlir::DataLayoutEntryListRef params) const {
  // Pointers are 8-byte aligned, matching the fixed 64-bit size above.
  // FIXME: improve this in face of address spaces
  assert(!cir::MissingFeatures::dataLayoutPtrHandlingBasedOnLangAS());
  constexpr uint64_t pointerAlignInBytes = 8;
  return pointerAlignInBytes;
}

/// Size of the record in bits.
///
/// A union occupies its layout-determining member (see getLargestMember)
/// plus, when padded, the trailing padding member. A struct/class uses the
/// sequential layout computed by computeStructSize (in bytes, hence * 8).
llvm::TypeSize
RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                              mlir::DataLayoutEntryListRef params) const {
  if (isUnion()) {
    llvm::ArrayRef<mlir::Type> members = getMembers();
    // An empty union has no storage; getLargestMember needs a member.
    if (members.empty())
      return llvm::TypeSize::getFixed(0);

    // getTypeSizeInBits (not getTypeSize, which returns bytes): this function
    // must report bits.
    uint64_t unionBits =
        dataLayout.getTypeSizeInBits(getLargestMember(dataLayout));
    // getLargestMember skips the trailing padding member of a padded union,
    // so add it back unless it is the only member (then it was the one
    // returned and is already counted).
    if (getPadded() && members.size() > 1)
      unionBits += dataLayout.getTypeSizeInBits(members.back());
    return llvm::TypeSize::getFixed(unionBits);
  }

  auto recordSize = static_cast<uint64_t>(computeStructSize(dataLayout));
  return llvm::TypeSize::getFixed(recordSize * 8);
}

/// ABI alignment of the record in bytes.
///
/// Packed records (unions included) are 1-aligned. Otherwise a union takes
/// the alignment of its layout-determining member and a struct/class the
/// strictest alignment of its members.
uint64_t
RecordType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
                            ::mlir::DataLayoutEntryListRef params) const {
  // Packed records always have an ABI alignment of 1. This applies to unions
  // too: a packed union does not keep its widest member's alignment.
  if (getPacked())
    return 1;

  if (isUnion()) {
    // An empty union has no member to take its alignment from.
    if (getMembers().empty())
      return 1;
    return dataLayout.getTypeABIAlignment(getLargestMember(dataLayout));
  }

  return computeStructAlignment(dataLayout);
}

/// Size in bytes of a record whose members are laid out one after another.
///
/// Mirrors LLVM's StructLayout. Each member is placed at the next offset
/// satisfying its ABI alignment (or 1 when the record is packed). The total
/// is then rounded up to the strictest member alignment, so that arrays of
/// the record stay aligned.
unsigned
RecordType::computeStructSize(const mlir::DataLayout &dataLayout) const {
  assert(isComplete() && "Cannot get layout of incomplete records");

  const bool isPacked = getPacked();
  uint64_t maxAlign = 1;
  unsigned size = 0;

  for (mlir::Type member : getMembers()) {
    // ABI (not preferred) alignment, matching LLVM.
    const uint64_t memberAlign =
        isPacked ? 1 : dataLayout.getTypeABIAlignment(member);
    maxAlign = std::max(maxAlign, memberAlign);

    // Pad up to the member's alignment, then reserve its storage.
    size = llvm::alignTo(size, memberAlign);
    size += dataLayout.getTypeSize(member);
  }

  // Tail padding up to the record's own alignment.
  size = llvm::alignTo(size, maxAlign);
  return size;
}

// ABI alignment in bytes of a record whose members are laid out sequentially.
// Packed records are 1-aligned. Otherwise the result is the strictest member
// ABI alignment, or 1 if there are no members. computeStructSize derives the
// same value internally, but computing it separately here is cheaper. Ideally
// both would be computed once and cached, but that is not implemented yet.
// TODO(CIR): Implement a way to cache the result.
uint64_t
RecordType::computeStructAlignment(const mlir::DataLayout &dataLayout) const {
  assert(isComplete() && "Cannot get layout of incomplete records");

  // computeStructSize places every member of a packed record at alignment 1,
  // so the record itself must be 1-aligned for the two to agree.
  if (getPacked())
    return 1;

  // This is a similar algorithm to LLVM's StructLayout.
  uint64_t recordAlignment = 1;
  for (mlir::Type ty : getMembers())
    recordAlignment =
        std::max(dataLayout.getTypeABIAlignment(ty), recordAlignment);

  return recordAlignment;
}

/// Byte offset of member \p idx within the record.
///
/// Union members all live at offset 0. For struct/class records, the members
/// before \p idx are laid out as in computeStructSize, and the result is then
/// aligned for member \p idx itself.
uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout,
                                      unsigned idx) const {
  assert(idx < getMembers().size() && "access not valid");

  // All union elements are at offset zero.
  if (isUnion() || idx == 0)
    return 0;

  assert(isComplete() && "Cannot get layout of incomplete records");
  llvm::ArrayRef<mlir::Type> members = getMembers();
  const bool isPacked = getPacked();

  // Alignment used to place a member. This matches LLVM since it uses the ABI
  // instead of preferred alignment; packed records use 1.
  auto memberAlign = [&](mlir::Type ty) {
    return llvm::Align(isPacked ? 1 : dataLayout.getTypeABIAlignment(ty));
  };

  // Accumulate in 64 bits, matching the return type, so that offsets in
  // large records do not wrap.
  uint64_t offset = 0;
  for (mlir::Type ty : members.take_front(idx)) {
    // Add padding if necessary to align the data element properly, then
    // consume space for this data item.
    offset = llvm::alignTo(offset, memberAlign(ty));
    offset += dataLayout.getTypeSize(ty);
  }

  // Account for padding, if necessary, for the alignment of the field whose
  // offset we are calculating.
  return llvm::alignTo(offset, memberAlign(members[idx]));
}

//===----------------------------------------------------------------------===//
// IntType Definitions
//===----------------------------------------------------------------------===//

/// Parse `<s, N>` or `<u, N>`, where N must lie within
/// [IntType::minBitwidth(), IntType::maxBitwidth()].
Type IntType::parse(mlir::AsmParser &parser) {
  mlir::MLIRContext *ctx = parser.getBuilder().getContext();
  // Custom diagnostics below point at the start of the type body.
  const llvm::SMLoc startLoc = parser.getCurrentLocation();

  if (parser.parseLess())
    return {};

  // Signedness keyword.
  llvm::StringRef signKeyword;
  if (parser.parseKeyword(&signKeyword))
    return {};
  const bool signedInt = signKeyword == "s";
  if (!signedInt && signKeyword != "u") {
    parser.emitError(startLoc, "expected 's' or 'u'");
    return {};
  }

  // Bit width.
  unsigned bitWidth = 0;
  if (parser.parseComma() || parser.parseInteger(bitWidth))
    return {};
  const bool widthInRange =
      bitWidth >= IntType::minBitwidth() && bitWidth <= IntType::maxBitwidth();
  if (!widthInRange) {
    parser.emitError(startLoc, "expected integer width to be from ")
        << IntType::minBitwidth() << " up to " << IntType::maxBitwidth();
    return {};
  }

  if (parser.parseGreater())
    return {};

  return IntType::get(ctx, bitWidth, signedInt);
}

/// Print as `<s, N>` / `<u, N>`, the form accepted by IntType::parse.
void IntType::print(mlir::AsmPrinter &printer) const {
  printer << '<' << (isSigned() ? 's' : 'u') << ", " << getWidth() << '>';
}

llvm::TypeSize
IntType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                           mlir::DataLayoutEntryListRef params) const {
  // The storage size is exactly the declared bit width.
  const auto widthInBits = getWidth();
  return llvm::TypeSize::getFixed(widthInBits);
}

/// ABI alignment in bytes: the integer's byte size rounded up to a power of
/// two.
///
/// For power-of-two widths of 8 bits or more this equals width / 8. Plain
/// width / 8 would be 0 for widths below 8, and not a power of two for
/// widths such as 24, both invalid alignments. Target-specific ABI rules for
/// such odd widths are not modelled here.
uint64_t IntType::getABIAlignment(const mlir::DataLayout &dataLayout,
                                  mlir::DataLayoutEntryListRef params) const {
  const uint64_t sizeInBytes = (static_cast<uint64_t>(getWidth()) + 7) / 8;
  return llvm::PowerOf2Ceil(sizeInBytes);
}

mlir::LogicalResult
IntType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                unsigned width, bool isSigned) {
  // The width must lie in [minBitwidth, maxBitwidth]; signedness is free.
  const bool tooNarrow = width < IntType::minBitwidth();
  const bool tooWide = width > IntType::maxBitwidth();
  if (!tooNarrow && !tooWide)
    return mlir::success();
  return emitError() << "IntType only supports widths from "
                     << IntType::minBitwidth() << " up to "
                     << IntType::maxBitwidth();
}

// Accepts exactly the widths 8, 16, 32 and 64.
bool cir::isValidFundamentalIntWidth(unsigned width) {
  switch (width) {
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

//===----------------------------------------------------------------------===//
// Floating-point type definitions
//===----------------------------------------------------------------------===//

// IEEE-754 binary32.
const llvm::fltSemantics &SingleType::getFloatSemantics() const {
  return llvm::APFloat::IEEEsingle();
}

llvm::TypeSize
SingleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                              mlir::DataLayoutEntryListRef params) const {
  // Stored in exactly getWidth() bits.
  const auto widthInBits = getWidth();
  return llvm::TypeSize::getFixed(widthInBits);
}

uint64_t
SingleType::getABIAlignment(const mlir::DataLayout &dataLayout,
                            mlir::DataLayoutEntryListRef params) const {
  // Naturally aligned: the alignment in bytes equals the byte size.
  return static_cast<uint64_t>(getWidth() / 8);
}

// IEEE-754 binary64.
const llvm::fltSemantics &DoubleType::getFloatSemantics() const {
  return llvm::APFloat::IEEEdouble();
}

llvm::TypeSize
DoubleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                              mlir::DataLayoutEntryListRef params) const {
  // Stored in exactly getWidth() bits.
  const auto widthInBits = getWidth();
  return llvm::TypeSize::getFixed(widthInBits);
}

uint64_t
DoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
                            mlir::DataLayoutEntryListRef params) const {
  // Naturally aligned: the alignment in bytes equals the byte size.
  return static_cast<uint64_t>(getWidth() / 8);
}

// IEEE-754 binary16.
const llvm::fltSemantics &FP16Type::getFloatSemantics() const {
  return llvm::APFloat::IEEEhalf();
}

llvm::TypeSize
FP16Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                            mlir::DataLayoutEntryListRef params) const {
  // Stored in exactly getWidth() bits.
  const auto widthInBits = getWidth();
  return llvm::TypeSize::getFixed(widthInBits);
}

uint64_t FP16Type::getABIAlignment(const mlir::DataLayout &dataLayout,
                                   mlir::DataLayoutEntryListRef params) const {
  // Naturally aligned: the alignment in bytes equals the byte size.
  return static_cast<uint64_t>(getWidth() / 8);
}

// bfloat16.
const llvm::fltSemantics &BF16Type::getFloatSemantics() const {
  return llvm::APFloat::BFloat();
}

llvm::TypeSize
BF16Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                            mlir::DataLayoutEntryListRef params) const {
  // Stored in exactly getWidth() bits.
  const auto widthInBits = getWidth();
  return llvm::TypeSize::getFixed(widthInBits);
}

uint64_t BF16Type::getABIAlignment(const mlir::DataLayout &dataLayout,
                                   mlir::DataLayoutEntryListRef params) const {
  // Naturally aligned: the alignment in bytes equals the byte size.
  return static_cast<uint64_t>(getWidth() / 8);
}

// x87 80-bit extended precision.
const llvm::fltSemantics &FP80Type::getFloatSemantics() const {
  return llvm::APFloat::x87DoubleExtended();
}

llvm::TypeSize
FP80Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                            mlir::DataLayoutEntryListRef params) const {
  // Though only 80 bits are used for the value, the type is 128 bits in size.
  constexpr uint64_t storageBits = 128;
  return llvm::TypeSize::getFixed(storageBits);
}

uint64_t FP80Type::getABIAlignment(const mlir::DataLayout &dataLayout,
                                   mlir::DataLayoutEntryListRef params) const {
  // Aligned to its full 16-byte storage size.
  constexpr uint64_t alignInBytes = 16;
  return alignInBytes;
}

// IEEE-754 binary128.
const llvm::fltSemantics &FP128Type::getFloatSemantics() const {
  return llvm::APFloat::IEEEquad();
}

llvm::TypeSize
FP128Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                             mlir::DataLayoutEntryListRef params) const {
  // Stored in exactly getWidth() bits.
  const auto widthInBits = getWidth();
  return llvm::TypeSize::getFixed(widthInBits);
}

uint64_t FP128Type::getABIAlignment(const mlir::DataLayout &dataLayout,
                                    mlir::DataLayoutEntryListRef params) const {
  // Fixed 16-byte alignment.
  constexpr uint64_t alignInBytes = 16;
  return alignInBytes;
}

// `long double` forwards semantics, size and alignment to the concrete
// floating-point type it wraps.

const llvm::fltSemantics &LongDoubleType::getFloatSemantics() const {
  auto underlyingFP = mlir::cast<cir::FPTypeInterface>(getUnderlying());
  return underlyingFP.getFloatSemantics();
}

llvm::TypeSize
LongDoubleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                                  mlir::DataLayoutEntryListRef params) const {
  auto underlyingLayout =
      mlir::cast<mlir::DataLayoutTypeInterface>(getUnderlying());
  return underlyingLayout.getTypeSizeInBits(dataLayout, params);
}

uint64_t
LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
                                mlir::DataLayoutEntryListRef params) const {
  auto underlyingLayout =
      mlir::cast<mlir::DataLayoutTypeInterface>(getUnderlying());
  return underlyingLayout.getABIAlignment(dataLayout, params);
}

//===----------------------------------------------------------------------===//
// ComplexType Definitions
//===----------------------------------------------------------------------===//

// C17 6.2.5p13:
//   Each complex type has the same representation and alignment requirements
//   as an array type containing exactly two elements of the corresponding
//   real type.
// Size and alignment are therefore those of a two-element array of the
// element type.

llvm::TypeSize
cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                                    mlir::DataLayoutEntryListRef params) const {
  const llvm::TypeSize partBits = dataLayout.getTypeSizeInBits(getElementType());
  return partBits * 2;
}

uint64_t
cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout,
                                  mlir::DataLayoutEntryListRef params) const {
  const mlir::Type partType = getElementType();
  return dataLayout.getTypeABIAlignment(partType);
}

/// Clone this function type with new inputs and an MLIR-style result list.
///
/// \p results follows the MLIR convention used by getReturnTypes(): empty for
/// a function returning nothing, otherwise exactly one type. An empty list is
/// mapped back to the C-style !cir.void return, which the builder stores as
/// "no return type".
FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
  assert(results.size() <= 1 && "expected at most one result type");
  mlir::Type returnType;
  if (results.empty())
    returnType = cir::VoidType::get(getContext());
  else
    returnType = results.front();
  return get(llvm::to_vector(inputs), returnType, isVarArg());
}

// Custom parser for a parenthesized, comma-separated parameter list of the
// form `(<type>*, ...)`. A trailing `...` marks the function as variadic and
// must be the last entry.
static mlir::ParseResult
parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
                    bool &isVarArg) {
  isVarArg = false;

  auto parseOneEntry = [&]() -> mlir::ParseResult {
    // Nothing may follow the ellipsis.
    if (isVarArg)
      return p.emitError(p.getCurrentLocation(),
                         "variadic `...` must be the last parameter");

    if (p.parseOptionalEllipsis().succeeded()) {
      isVarArg = true;
      return mlir::success();
    }

    mlir::Type paramType;
    if (p.parseType(paramType).failed())
      return mlir::failure();
    params.push_back(paramType);
    return mlir::success();
  };

  return p.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseOneEntry);
}

// Custom printer matching parseFuncTypeParams: `(T1, T2, ...)`, with the
// ellipsis emitted only for variadic functions.
static void printFuncTypeParams(mlir::AsmPrinter &p,
                                mlir::ArrayRef<mlir::Type> params,
                                bool isVarArg) {
  p << '(';
  llvm::StringRef separator = "";
  for (mlir::Type paramType : params) {
    p << separator;
    p.printType(paramType);
    separator = ", ";
  }
  // The separator is still empty when there were no parameters.
  if (isVarArg)
    p << separator << "...";
  p << ')';
}

/// C-style return type: !cir.void when the function returns nothing,
/// otherwise the stored return type.
mlir::Type FuncType::getReturnType() const {
  if (!hasVoidReturn())
    return getOptionalReturnType();
  return cir::VoidType::get(getContext());
}

/// MLIR-style result list: empty when the function returns nothing,
/// otherwise a single-element list holding the return type.
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
  if (hasVoidReturn())
    return {};
  // The ArrayRef must point into the type storage. Wrapping the value
  // returned by getOptionalReturnType() would reference a temporary that is
  // gone once this function returns.
  const mlir::Type &storedReturnType = getImpl()->optionalReturnType;
  return storedReturnType;
}

// Does the function type return nothing?
bool FuncType::hasVoidReturn() const {
  const bool hasReturnType = static_cast<bool>(getOptionalReturnType());
  return !hasReturnType;
}

mlir::LogicalResult
FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                 llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
                 bool isVarArg) {
  // A void result is represented by the absence of a return type, never by
  // an explicit !cir.void.
  const bool explicitVoid = mlir::isa_and_nonnull<cir::VoidType>(returnType);
  if (!explicitVoid)
    return mlir::success();
  return emitError() << "!cir.func cannot have an explicit 'void' return type";
}

//===----------------------------------------------------------------------===//
// BoolType
//===----------------------------------------------------------------------===//

// A bool occupies one byte and is byte-aligned.

llvm::TypeSize
BoolType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
                            ::mlir::DataLayoutEntryListRef params) const {
  constexpr uint64_t boolBits = 8;
  return llvm::TypeSize::getFixed(boolBits);
}

uint64_t
BoolType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
                          ::mlir::DataLayoutEntryListRef params) const {
  constexpr uint64_t boolAlign = 1;
  return boolAlign;
}

//===----------------------------------------------------------------------===//
//  DataMemberType Definitions
//===----------------------------------------------------------------------===//

// Data member pointers are currently modelled as 64-bit, 8-byte-aligned
// values for every C++ ABI.

llvm::TypeSize
DataMemberType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
                                  ::mlir::DataLayoutEntryListRef params) const {
  // FIXME: consider size differences under different ABIs
  assert(!MissingFeatures::cxxABI());
  constexpr uint64_t memberPtrBits = 64;
  return llvm::TypeSize::getFixed(memberPtrBits);
}

uint64_t
DataMemberType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
                                ::mlir::DataLayoutEntryListRef params) const {
  // FIXME: consider alignment differences under different ABIs
  assert(!MissingFeatures::cxxABI());
  constexpr uint64_t memberPtrAlign = 8;
  return memberPtrAlign;
}

//===----------------------------------------------------------------------===//
//  VPtrType Definitions
//===----------------------------------------------------------------------===//

// Vtable pointers are currently modelled as 64-bit, 8-byte-aligned values.

llvm::TypeSize
VPtrType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                            mlir::DataLayoutEntryListRef params) const {
  // FIXME: consider size differences under different ABIs
  constexpr uint64_t vptrBits = 64;
  return llvm::TypeSize::getFixed(vptrBits);
}

uint64_t VPtrType::getABIAlignment(const mlir::DataLayout &dataLayout,
                                   mlir::DataLayoutEntryListRef params) const {
  // FIXME: consider alignment differences under different ABIs
  constexpr uint64_t vptrAlign = 8;
  return vptrAlign;
}

//===----------------------------------------------------------------------===//
//  ArrayType Definitions
//===----------------------------------------------------------------------===//

llvm::TypeSize
ArrayType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
                             ::mlir::DataLayoutEntryListRef params) const {
  // getSize() copies of the element's size in bits.
  const llvm::TypeSize eltBits = dataLayout.getTypeSizeInBits(getElementType());
  return getSize() * eltBits;
}

uint64_t
ArrayType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
                           ::mlir::DataLayoutEntryListRef params) const {
  // An array is aligned like its element type.
  const mlir::Type eltType = getElementType();
  return dataLayout.getTypeABIAlignment(eltType);
}

//===----------------------------------------------------------------------===//
// VectorType Definitions
//===----------------------------------------------------------------------===//

llvm::TypeSize cir::VectorType::getTypeSizeInBits(
    const ::mlir::DataLayout &dataLayout,
    ::mlir::DataLayoutEntryListRef params) const {
  // Element count times element size, always reported as a fixed size.
  const llvm::TypeSize eltBits = dataLayout.getTypeSizeInBits(getElementType());
  return llvm::TypeSize::getFixed(getSize() * eltBits);
}

/// ABI alignment of the vector in bytes: its byte size rounded up to a power
/// of two (e.g. 16 for a 128-bit vector, 16 for a 96-bit one).
///
/// Like every other getABIAlignment in this file, the result is in bytes.
/// PowerOf2Ceil keeps exact powers of two unchanged, unlike NextPowerOf2,
/// which would double them.
uint64_t
cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
                                 ::mlir::DataLayoutEntryListRef params) const {
  const uint64_t sizeInBits = dataLayout.getTypeSizeInBits(*this);
  const uint64_t sizeInBytes = (sizeInBits + 7) / 8;
  return llvm::PowerOf2Ceil(sizeInBytes);
}

mlir::LogicalResult cir::VectorType::verify(
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
    mlir::Type elementType, uint64_t size, bool scalable) {
  // A vector must hold at least one element.
  if (size != 0)
    return success();
  return emitError() << "the number of vector elements must be non-zero";
}

/// Parse `<N x T>` or, for a scalable vector, `<[N] x T>`.
mlir::Type cir::VectorType::parse(::mlir::AsmParser &odsParser) {
  const llvm::SMLoc typeLoc = odsParser.getCurrentLocation();

  if (odsParser.parseLess())
    return {};

  // A '[' before the element count marks the vector as scalable.
  const bool scalable = odsParser.parseOptionalLSquare().succeeded();

  mlir::FailureOr<uint64_t> numElements =
      mlir::FieldParser<uint64_t>::parse(odsParser);
  if (mlir::failed(numElements)) {
    odsParser.emitError(odsParser.getCurrentLocation(),
                        "failed to parse CIR_VectorType parameter 'size' which "
                        "is to be a `uint64_t`");
    return {};
  }

  // A scalable element count must be closed by ']'.
  if (scalable && odsParser.parseRSquare().failed()) {
    odsParser.emitError(odsParser.getCurrentLocation(),
                        "missing closing `]` for scalable dim size");
    return {};
  }

  if (odsParser.parseKeyword("x"))
    return {};

  mlir::FailureOr<::mlir::Type> eltType =
      mlir::FieldParser<::mlir::Type>::parse(odsParser);
  if (mlir::failed(eltType)) {
    odsParser.emitError(odsParser.getCurrentLocation(),
                        "failed to parse CIR_VectorType parameter "
                        "'elementType' which is to be a `mlir::Type`");
    return {};
  }

  if (odsParser.parseGreater())
    return {};

  // getChecked runs VectorType::verify (e.g. rejecting a zero element count).
  const mlir::Type parsedEltType = *eltType;
  const uint64_t parsedCount = *numElements;
  return odsParser.getChecked<VectorType>(typeLoc, odsParser.getContext(),
                                          parsedEltType, parsedCount, scalable);
}

/// Print `<N x T>`, or `<[N] x T>` for scalable vectors, matching parse().
void cir::VectorType::print(mlir::AsmPrinter &odsPrinter) const {
  const bool scalable = getIsScalable();

  odsPrinter << "<";
  if (scalable)
    odsPrinter << "[";
  odsPrinter.printStrippedAttrOrType(getSize());
  if (scalable)
    odsPrinter << "]";

  odsPrinter << " x ";
  odsPrinter.printStrippedAttrOrType(getElementType());
  odsPrinter << ">";
}

//===----------------------------------------------------------------------===//
// TargetAddressSpace definitions
//===----------------------------------------------------------------------===//

/// Build the CIR target address-space attribute for the AST address space
/// \p langAS.
///
/// The value is stored as a 32-bit unsigned integer attribute, the same
/// representation produced by parseTargetAddressSpace.
cir::TargetAddressSpaceAttr
cir::toCIRTargetAddressSpace(mlir::MLIRContext &context, clang::LangAS langAS) {
  // Do not use llvm::APSInt(unsigned) here. That constructor takes a *bit
  // width* and yields the value 0, which would lose the address space number.
  const unsigned targetAS = clang::toTargetAddressSpace(langAS);
  return cir::TargetAddressSpaceAttr::get(
      &context, mlir::Builder(&context).getUI32IntegerAttr(targetAS));
}

bool cir::isMatchingAddressSpace(cir::TargetAddressSpaceAttr cirAS,
                                 clang::LangAS as) {
  // No CIR attribute means the "default" address space, which only matches
  // LangAS::Default.
  if (!cirAS)
    return as == clang::LangAS::Default;

  // Otherwise the AST address space must be a target address space with the
  // same numeric value as the attribute.
  return clang::isTargetAddressSpace(as) &&
         cirAS.getValue().getUInt() == clang::toTargetAddressSpace(as);
}

// Custom parser for the `addrspace` parameter of `!cir.ptr`, in the format
// `target_address_space(N)`. N must fit in a 32-bit unsigned integer, the
// type of the attribute that is built.
mlir::ParseResult parseTargetAddressSpace(mlir::AsmParser &p,
                                          cir::TargetAddressSpaceAttr &attr) {
  if (failed(p.parseKeyword("target_address_space")))
    return mlir::failure();

  if (failed(p.parseLParen()))
    return mlir::failure();

  // Parse into a wider signed integer so that negative or too-large values
  // are diagnosed instead of silently wrapping into the ui32 attribute.
  const llvm::SMLoc valueLoc = p.getCurrentLocation();
  int64_t targetValue;
  if (failed(p.parseInteger(targetValue)))
    return p.emitError(p.getCurrentLocation(),
                       "expected integer address space value");
  if (targetValue < 0 || !llvm::isUInt<32>(targetValue))
    return p.emitError(valueLoc, "address space value must be a 32-bit "
                                 "unsigned integer");

  if (failed(p.parseRParen()))
    return p.emitError(p.getCurrentLocation(),
                       "expected ')' after address space value");

  mlir::MLIRContext *context = p.getBuilder().getContext();
  attr = cir::TargetAddressSpaceAttr::get(
      context,
      p.getBuilder().getUI32IntegerAttr(static_cast<uint32_t>(targetValue)));
  return mlir::success();
}

// Custom printer for the `addrspace` parameter of `!cir.ptr`. Emits
// `target_address_space(N)`, the form read by parseTargetAddressSpace.
void printTargetAddressSpace(mlir::AsmPrinter &p,
                             cir::TargetAddressSpaceAttr attr) {
  const uint64_t addrSpaceValue = attr.getValue().getUInt();
  p << "target_address_space(" << addrSpaceValue << ")";
}

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//

/// Register the CIR types with the dialect.
void CIRDialect::registerTypes() {
  // Register tablegen'd types.
  // With GET_TYPEDEF_LIST defined, the generated .inc expands to the list of
  // generated type classes, which become the addTypes<> template arguments.
  addTypes<
#define GET_TYPEDEF_LIST
#include "clang/CIR/Dialect/IR/CIROpsTypes.cpp.inc"
      >();

  // Register raw C++ types.
  // TODO(CIR) addTypes<RecordType>();
}
