/*******************************************************************************
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/intel/jit/ir/gemm_schedule.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace intel {
namespace jit {

// Maps a view to a BMNK layout. The view is first converted to its pseudo
// vlayout, which is then handed to the layout-based overload of map_to_bmnk().
layout_t bmnk_mapper_t::map_to_bmnk(abc_kind_t abc_kind,
        const std::vector<bmnk_kind_t> &bmnk_kinds, const view_t &view) const {
    return map_to_bmnk(abc_kind, bmnk_kinds, view.create_pseudo_vlayout());
}

// Re-indexes the blocks of a layout in terms of BMNK dimensions.
//
// For every block of `layout` the BMNK kind of its dimension is looked up
// (for the tensor selected by abc_kind) and the block is re-emitted with the
// position of that kind inside `bmnk_kinds` as its index; block size and
// stride are carried over unchanged. A block whose kind is not listed in
// `bmnk_kinds` is reported through gpu_error_not_expected() and is not added
// to the result.
//
// The result keeps the type of `layout`; the trailing constructor arguments
// are 0 and bmnk_kinds.size() (presumably offset and number of dimensions -
// confirm against layout_t).
layout_t bmnk_mapper_t::map_to_bmnk(abc_kind_t abc_kind,
        const std::vector<bmnk_kind_t> &bmnk_kinds,
        const layout_t &layout) const {
    const int nkinds = int(bmnk_kinds.size());
    std::vector<layout_block_t> mapped_blocks;
    for (auto &blk : layout.blocks()) {
        const auto blk_kind = bmnk_kind(abc_kind, blk.idx);

        // Locate the first entry of bmnk_kinds matching the block kind.
        int pos = 0;
        while (pos < nkinds && bmnk_kinds[pos] != blk_kind)
            pos++;

        if (pos == nkinds) {
            gpu_error_not_expected() << "MNK dimension not found.";
            continue;
        }
        mapped_blocks.emplace_back(pos, blk.size, blk.stride);
    }
    return layout_t(layout.type(), mapped_blocks, 0, nkinds);
}

// Maps a BMNK layout back to a layout of the tensor selected by abc_kind.
// The blocks of `abc_layout` are registered with a temporary
// bmnk_block_mapper_t, which then performs the conversion of `bmnk_layout`.
layout_t bmnk_mapper_t::map_from_bmnk(abc_kind_t abc_kind,
        const std::vector<bmnk_kind_t> &bmnk_kinds, const layout_t &bmnk_layout,
        const layout_t &abc_layout) const {
    bmnk_block_mapper_t block_mapper(*this);
    const auto &abc_blocks = abc_layout.blocks();
    block_mapper.push_blocks(abc_kind, abc_blocks);
    return block_mapper.map_from_bmnk(abc_kind, bmnk_kinds, bmnk_layout);
}

// Registers a block of the tensor selected by abc_kind in the list that
// corresponds to the BMNK kind of the block dimension.
//
// Batch (b) blocks are recorded only when they come from the A tensor; batch
// blocks of other tensors are silently skipped. Any kind other than b/m/n/k
// is reported through gpu_error_not_expected().
void bmnk_block_mapper_t::push_block(
        abc_kind_t abc_kind, const layout_block_t &b) {
    const auto kind = bmnk_mapper_.bmnk_kind(abc_kind, b.idx);
    if (kind == bmnk_kind_t::b) {
        if (abc_kind == abc_kind_t::a) b_blocks_.emplace_back(abc_kind, b);
    } else if (kind == bmnk_kind_t::m) {
        m_blocks_.emplace_back(abc_kind, b);
    } else if (kind == bmnk_kind_t::n) {
        n_blocks_.emplace_back(abc_kind, b);
    } else if (kind == bmnk_kind_t::k) {
        k_blocks_.emplace_back(abc_kind, b);
    } else {
        gpu_error_not_expected() << "Unknown MNK kind.";
    }
}

// Converts a BMNK layout into a layout of the tensor selected by abc_kind,
// using the blocks that were registered earlier with push_block().
//
// bmnk_kinds[i] is the BMNK kind of dimension i of `bmnk_layout`. The layout
// is asserted to have at most three dimensions and a zero offset. Each block
// of `bmnk_layout` is matched against the registered blocks of its kind via
// pop_block(); afterwards every queue of the requested kinds must be empty
// once pop_size_1_blocks() has been applied to it.
//
// The returned layout keeps the type of `bmnk_layout`, has its strides
// rewritten so that the blocks are densely packed, and is constructed with 0
// and bmnk_mapper_.ndims(abc_kind) as trailing arguments (presumably offset
// and number of dimensions - confirm against layout_t).
layout_t bmnk_block_mapper_t::map_from_bmnk(abc_kind_t abc_kind,
        const std::vector<bmnk_kind_t> &bmnk_kinds,
        const layout_t &bmnk_layout) const {
    gpu_assert(bmnk_layout.ndims() <= 3);
    gpu_assert(is_zero(bmnk_layout.offset()));

    // Position of a BMNK kind inside the per-kind queue table.
    const auto slot = [](bmnk_kind_t kind) { return static_cast<int>(kind); };

    // One queue of not yet consumed problem blocks per BMNK kind.
    std::vector<std::vector<layout_block_t>> pending(
            slot(bmnk_kind_t::k) + 1);
    pending[slot(bmnk_kind_t::b)] = create_prb_blocks(abc_kind, b_blocks_);
    pending[slot(bmnk_kind_t::m)] = create_prb_blocks(abc_kind, m_blocks_);
    pending[slot(bmnk_kind_t::n)] = create_prb_blocks(abc_kind, n_blocks_);
    pending[slot(bmnk_kind_t::k)] = create_prb_blocks(abc_kind, k_blocks_);

    // Translate the BMNK blocks one by one, in layout order.
    std::vector<layout_block_t> prb_blocks;
    for (auto &bmnk_block : bmnk_layout.blocks()) {
        auto &queue = pending[slot(bmnk_kinds[bmnk_block.idx])];
        bool ok = pop_block(queue, prb_blocks, bmnk_block);
        gpu_assert(ok) << "Can't map from bmnk layout to problem layout.";
        MAYBE_UNUSED(ok);
    }

    // Nothing may be left over for the kinds that were requested.
    for (auto kind : bmnk_kinds) {
        auto &queue = pending[slot(kind)];
        pop_size_1_blocks(queue);
        gpu_assert(queue.empty());
    }

    // Fix strides to make them dense.
    dim_t next_stride = 1;
    for (auto &blk : prb_blocks) {
        blk.stride = stride_t(next_stride);
        next_stride *= blk.size;
    }

    return layout_t(
            bmnk_layout.type(), prb_blocks, 0, bmnk_mapper_.ndims(abc_kind));
}

bool bmnk_block_mapper_t::pop_block(std::vector<layout_block_t> &bmnk_blocks,
        std::vector<layout_block_t> &prb_blocks,
        const layout_block_t &bmnk_block) const {
    if (bmnk_block.size == 1) return true;

    pop_size_1_blocks(bmnk_blocks);
    if (bmnk_blocks.empty()) return false;

    auto &next_block = bmnk_blocks.front();
    dim_t common_size = math::gcd(next_block.size, bmnk_block.size);
    if (common_size == bmnk_block.size) {
        prb_blocks.emplace_back(next_block.idx, common_size, next_block.stride);
        next_block.size /= common_size;
        next_block.stride *= common_size;
        return true;
    } else if (common_size == next_block.size) {
        prb_blocks.emplace_back(next_block.idx, common_size, next_block.stride);
        bmnk_blocks.erase(bmnk_blocks.begin());
        auto tmp_block = bmnk_block;
        tmp_block.size /= common_size;
        return pop_block(bmnk_blocks, prb_blocks, tmp_block);
    }
    return false;
}

} // namespace jit
} // namespace intel
} // namespace gpu
} // namespace impl
} // namespace dnnl
