use std::mem::MaybeUninit;

use util::{
    math::{Rect2f, Vec2},
    rc::{Arc, UniqueArc},
};

use super::PixelFormat;
use crate::color::{Premultiply, BGRA8};

mod blit;
pub(super) mod blur;
use blur::gaussian_sigma_to_box_radius;
mod strip;
pub use strip::*;

/// A pixel type the drawing primitives in this module can write to.
///
/// Implemented in this file for [`BGRA8`] and `u8`.
trait DrawPixel: Copy + Sized {
    /// Writes `value` onto the destination pixel `self`; how the two are combined
    /// is decided by each implementation.
    fn put(&mut self, value: Self);
    /// Returns this pixel value scaled by `scale`.
    // NOTE(review): presumably `scale` is a coverage in 0..=255 where 255 leaves the
    //               value unchanged — confirm against the color helpers.
    fn scale_alpha(self, scale: u8) -> Self;
    /// The [`PixelFormat`] this pixel type corresponds to.
    const PIXEL_FORMAT: PixelFormat;
    /// Returns the pixel slice inside `buffer` if it stores pixels of this type,
    /// or `None` if it stores a different pixel type.
    fn cast_target_buffer<'a>(buffer: RenderTargetBufferMut<'a>) -> Option<&'a mut [Self]>;
}

impl DrawPixel for BGRA8 {
    const PIXEL_FORMAT: PixelFormat = PixelFormat::Bgra;

    /// Replaces the destination pixel with the premultiplied form of `value`.
    fn put(&mut self, value: Self) {
        // TODO: blend_over
        let premultiplied = value.premultiply();
        *self = premultiplied.0;
    }

    /// Scales the alpha of this color by `scale`.
    fn scale_alpha(self, scale: u8) -> Self {
        self.mul_alpha(scale)
    }

    /// Hands back the pixel slice when `buffer` is the BGRA variant.
    fn cast_target_buffer<'a>(buffer: RenderTargetBufferMut<'a>) -> Option<&'a mut [Self]> {
        if let RenderTargetBufferMut::Bgra(pixels) = buffer {
            Some(pixels)
        } else {
            None
        }
    }
}

impl DrawPixel for u8 {
    const PIXEL_FORMAT: PixelFormat = PixelFormat::Mono;

    /// Adds `value` to the destination, saturating at `u8::MAX`.
    fn put(&mut self, value: Self) {
        // Use simple additive blending for monochrome rendering
        // TODO: It's kinda weird to be using these different blending modes for
        //       unannotated primitives like `u8` though, maybe it could be cleaned up?
        let blended = self.saturating_add(value);
        *self = blended;
    }

    /// Scales this monochrome value by `scale`.
    fn scale_alpha(self, scale: u8) -> Self {
        crate::color::mul_rgb(self, scale)
    }

    /// Hands back the pixel slice when `buffer` is the monochrome variant.
    fn cast_target_buffer<'a>(buffer: RenderTargetBufferMut<'a>) -> Option<&'a mut [Self]> {
        if let RenderTargetBufferMut::Mono(pixels) = buffer {
            Some(pixels)
        } else {
            None
        }
    }
}

/// Puts `color` on the pixels `x0..x1` of a single row, clipped to `0..width`.
///
/// `offset_buffer` must start at the first pixel of the row being drawn.
///
/// # Safety
///
/// Every index in `0..width` must be in bounds for `offset_buffer`; the writes are
/// not bounds-checked.
unsafe fn horizontal_line_unchecked<P: DrawPixel>(
    x0: i32,
    x1: i32,
    offset_buffer: &mut [P],
    width: i32,
    color: P,
) {
    let first = x0.clamp(0, width);
    let end = x1.clamp(0, width);

    for column in first..end {
        // SAFETY: `column` lies in `0..width`, which the caller guarantees is in bounds.
        let pixel = unsafe { offset_buffer.get_unchecked_mut(column as usize) };
        pixel.put(color);
    }
}

/// Puts `color` on the rows `y0..y1` of a single column, clipped to `0..height`.
///
/// `offset_buffer` must start at the column's pixel in row 0; consecutive rows are
/// `stride` elements apart.
///
/// # Safety
///
/// `y * stride` must be in bounds for `offset_buffer` for every `y` in `0..height`;
/// the writes are not bounds-checked.
unsafe fn vertical_line_unchecked<P: DrawPixel>(
    y0: i32,
    y1: i32,
    offset_buffer: &mut [P],
    height: i32,
    stride: i32,
    color: P,
) {
    let first = y0.clamp(0, height);
    let end = y1.clamp(0, height);

    for row in first..end {
        let index = (row * stride) as usize;
        // SAFETY: `row` lies in `0..height`, so the caller guarantees `index` is in bounds.
        let pixel = unsafe { offset_buffer.get_unchecked_mut(index) };
        pixel.put(color);
    }
}

// Panics if `$buffer` holds fewer than `$stride * $height` elements.
//
// `$what` is the name of the calling function; it is spliced into the panic
// message at compile time via `concat!`.
macro_rules! check_buffer {
    ($what: literal, $buffer: ident, $stride: ident, $height: ident) => {
        if $buffer.len() < $stride as usize * $height as usize {
            panic!(concat!(
                "Buffer passed to rasterize::",
                $what,
                " is too small"
            ))
        }
    };
}

/// Fills the axis-aligned rectangle `(x0, y0)`–`(x1, y1)` with `color` using
/// Scuffed Anti-Aliasing™ (SAA).
///
/// The fully covered interior is written with `color` unchanged. An edge that is
/// further than `AA_THRESHOLD` away from the nearest integer coordinate gets an
/// additional one pixel wide row/column drawn with `color` scaled by the edge's
/// fractional coverage; a corner shared by two such edges is scaled by the product
/// of both coverages. Everything is clipped to `0..width` × `0..height`.
///
/// `buffer` holds the target's rows `stride` pixels apart.
///
/// # Panics
///
/// Panics if `buffer` holds fewer than `stride * height` pixels, or if the last
/// row's `width` pixels do not fit into it (only possible when `width > stride`).
/// In debug builds it also panics if `x0 > x1` or `y0 > y1`.
fn fill_axis_aligned_antialias_rect<P: DrawPixel>(
    x0: f32,
    y0: f32,
    x1: f32,
    y1: f32,
    buffer: &mut [P],
    width: u32,
    height: u32,
    stride: u32,
    color: P,
) {
    check_buffer!("fill_axis_aligned_antialias_rect", buffer, stride, height);

    // `check_buffer!` only guarantees `stride * height` pixels, but the unchecked line
    // helpers below write up to `width` pixels into every row, including the last one.
    // Make sure that row fits as well so those writes can never leave `buffer`.
    if width > 0 && height > 0 {
        let required = (height as usize - 1) * stride as usize + width as usize;
        assert!(
            buffer.len() >= required,
            "Buffer passed to rasterize::fill_axis_aligned_antialias_rect is too small for its width"
        );
    }

    debug_assert!(x0 <= x1);
    debug_assert!(y0 <= y1);

    // NOTE(review): the two partial rows (and the two partial columns) are drawn
    //               independently, so a rectangle that starts and ends inside the same
    //               pixel row/column draws that row/column twice — confirm whether
    //               that is acceptable for SAA.

    // Edges closer than this to an integer coordinate are treated as pixel-aligned.
    const AA_THRESHOLD: f32 = 1. / 256.;

    // `full_left..full_right` are the columns that are covered completely.
    let (left_aa, full_left) = if (x0 - x0.round()).abs() > AA_THRESHOLD {
        (true, x0.ceil() as i32)
    } else {
        (false, x0.round() as i32)
    };

    let (right_aa, full_right) = if (x1 - x1.round()).abs() > AA_THRESHOLD {
        (true, x1.floor() as i32)
    } else {
        (false, x1.round() as i32)
    };

    // Partially covered top row (if any). `top_aa_width` is its coverage, or 1.0 when
    // the top edge is pixel-aligned; `full_top` is the first fully covered row.
    let (top_aa_width, full_top) = if (y0 - y0.round()).abs() > AA_THRESHOLD {
        let top_width = 1.0 - y0.fract();
        let top_fill = top_width * 255.;
        let top_y = y0.floor() as i32;
        if top_y >= 0 && top_y < height as i32 {
            // SAFETY: `top_y` is a valid row and the checks at the top of the function
            // guarantee that every row holds at least `width` pixels.
            unsafe {
                horizontal_line_unchecked(
                    full_left,
                    full_right,
                    &mut buffer[top_y as usize * stride as usize..],
                    width as i32,
                    color.scale_alpha(top_fill as u8),
                );
            }
        }
        (top_width, top_y + 1)
    } else {
        (1.0, y0.round() as i32)
    };

    // Partially covered bottom row (if any). `full_bottom` is one past the last fully
    // covered row, which is also the row index of the partial bottom row.
    let (bottom_aa_width, full_bottom) = if (y1 - y1.round()).abs() > AA_THRESHOLD {
        let bottom_width = y1.fract();
        let bottom_fill = bottom_width * 255.;
        let bottom_y = y1.floor() as i32;
        if bottom_y >= 0 && bottom_y < height as i32 {
            // SAFETY: `bottom_y` is a valid row and the checks at the top of the
            // function guarantee that every row holds at least `width` pixels.
            unsafe {
                horizontal_line_unchecked(
                    full_left,
                    full_right,
                    &mut buffer[bottom_y as usize * stride as usize..],
                    width as i32,
                    color.scale_alpha(bottom_fill as u8),
                );
            }
        }
        (bottom_width, bottom_y)
    } else {
        (1.0, y1.round() as i32)
    };

    if left_aa {
        let left_fill = (1.0 - x0.fract()) * 255.;
        let left_x = full_left - 1;
        if left_x >= 0 && left_x < width as i32 {
            // The partial top row lives at `full_top - 1`, which is a valid row for
            // every `full_top` in `1..=height` (so `height` itself must be accepted).
            if top_aa_width < 1.0 && full_top > 0 && full_top <= height as i32 {
                buffer[(full_top - 1) as usize * stride as usize + left_x as usize]
                    .put(color.scale_alpha((left_fill * top_aa_width) as u8));
            }

            // SAFETY: `left_x < width` and the checks at the top of the function
            // guarantee `y * stride + left_x` is in bounds for every `y < height`.
            unsafe {
                vertical_line_unchecked(
                    full_top,
                    full_bottom,
                    &mut buffer[left_x as usize..],
                    height as i32,
                    stride as i32,
                    color.scale_alpha(left_fill as u8),
                );
            }

            if bottom_aa_width < 1.0 && full_bottom >= 0 && full_bottom < height as i32 {
                buffer[full_bottom as usize * stride as usize + left_x as usize]
                    .put(color.scale_alpha((left_fill * bottom_aa_width) as u8));
            }
        }
    }

    if right_aa {
        let right_fill = x1.fract() * 255.;
        let right_x = full_right;
        if right_x >= 0 && right_x < width as i32 {
            // Same row range as for the left corner: `full_top` in `1..=height`.
            if top_aa_width < 1.0 && full_top > 0 && full_top <= height as i32 {
                buffer[(full_top - 1) as usize * stride as usize + right_x as usize]
                    .put(color.scale_alpha((right_fill * top_aa_width) as u8));
            }

            // SAFETY: `right_x < width` and the checks at the top of the function
            // guarantee `y * stride + right_x` is in bounds for every `y < height`.
            unsafe {
                vertical_line_unchecked(
                    full_top,
                    full_bottom,
                    &mut buffer[right_x as usize..],
                    height as i32,
                    stride as i32,
                    color.scale_alpha(right_fill as u8),
                );
            }

            if bottom_aa_width < 1.0 && full_bottom >= 0 && full_bottom < height as i32 {
                buffer[full_bottom as usize * stride as usize + right_x as usize]
                    .put(color.scale_alpha((right_fill * bottom_aa_width) as u8));
            }
        }
    }

    // Fully covered interior.
    for y in full_top.clamp(0, height as i32)..full_bottom.clamp(0, height as i32) {
        // SAFETY: `y` is a valid row and the checks at the top of the function
        // guarantee that every row holds at least `width` pixels.
        unsafe {
            horizontal_line_unchecked(
                full_left,
                full_right,
                &mut buffer[y as usize * stride as usize..],
                width as i32,
                color,
            );
        }
    }
}

/// A render target of the software rasterizer: a pixel buffer plus its dimensions.
pub(super) struct RenderTargetImpl<'a> {
    /// Pixel storage, either borrowed from the caller or owned by the target.
    buffer: RenderTargetBuffer<'a>,
    /// Width of the drawable area, in pixels.
    pub width: u32,
    /// Number of pixel rows.
    pub height: u32,
    /// Distance between the starts of consecutive rows, counted in buffer elements.
    pub stride: u32,
}

/// Pixel storage backing a [`RenderTargetImpl`].
enum RenderTargetBuffer<'a> {
    /// Monochrome pixels owned by the render target itself.
    OwnedMono(UniqueArc<[u8]>),
    /// BGRA pixels borrowed from the caller.
    BorrowedBgra(&'a mut [BGRA8]),
    /// Monochrome pixels borrowed from the caller.
    BorrowedMono(&'a mut [u8]),
}

impl RenderTargetBuffer<'_> {
    /// Reports the pixel format stored in this buffer, regardless of who owns it.
    fn pixel_format(&self) -> PixelFormat {
        match self {
            Self::OwnedMono(_) => PixelFormat::Mono,
            Self::BorrowedMono(_) => PixelFormat::Mono,
            Self::BorrowedBgra(_) => PixelFormat::Bgra,
        }
    }
}

/// A mutable view of a [`RenderTargetBuffer`]'s pixels that only distinguishes
/// the pixel format, not whether the pixels are owned or borrowed.
enum RenderTargetBufferMut<'a> {
    /// BGRA pixels.
    Bgra(&'a mut [BGRA8]),
    /// Monochrome pixels.
    Mono(&'a mut [u8]),
}

/// Wraps a caller-provided BGRA pixel buffer into a software render target.
///
/// `stride` is the distance between the starts of consecutive rows, counted in
/// elements of `buffer`.
///
/// # Panics
///
/// Panics if `buffer` holds fewer than `height * stride` pixels.
pub fn create_render_target(
    buffer: &mut [BGRA8],
    width: u32,
    height: u32,
    stride: u32,
) -> super::RenderTarget<'_> {
    let required_len = height as usize * stride as usize;
    assert!(
        buffer.len() >= required_len,
        "Buffer passed to rasterize::sw::create_render_target is too small!"
    );

    let target = RenderTargetImpl {
        buffer: RenderTargetBuffer::BorrowedBgra(buffer),
        width,
        height,
        stride,
    };
    super::RenderTarget(super::RenderTargetInner::Software(target))
}

/// Wraps a caller-provided monochrome pixel buffer into a software render target.
///
/// `stride` is the distance between the starts of consecutive rows, counted in
/// elements of `buffer`.
///
/// # Panics
///
/// Panics if `buffer` holds fewer than `height * stride` pixels.
pub fn create_render_target_mono(
    buffer: &mut [u8],
    width: u32,
    height: u32,
    stride: u32,
) -> super::RenderTarget<'_> {
    // The message names this function (not `create_render_target`) so a failure
    // points at the right call site.
    assert!(
        buffer.len() >= height as usize * stride as usize,
        "Buffer passed to rasterize::sw::create_render_target_mono is too small!"
    );
    super::RenderTarget(super::RenderTargetInner::Software(RenderTargetImpl {
        buffer: RenderTargetBuffer::BorrowedMono(buffer),
        width,
        height,
        stride,
    }))
}

/// Extracts the software render target from a backend-agnostic [`super::RenderTarget`].
///
/// # Panics
///
/// Panics if `target` belongs to a different rasterizer backend.
fn unwrap_sw_render_target<'a, 'b>(
    target: &'a mut super::RenderTarget<'b>,
) -> &'a mut RenderTargetImpl<'b> {
    #[cfg_attr(not(feature = "wgpu"), expect(unreachable_patterns))]
    match &mut target.0 {
        super::RenderTargetInner::Software(software) => software,
        other => {
            let variant = other.variant_name();
            panic!(
                "Incompatible render target {:?} passed to software rasterizer (expected: software)",
                variant
            )
        }
    }
}

impl RenderTargetBuffer<'_> {
    /// Borrows the pixels mutably, keeping only the pixel format distinction.
    fn as_mut(&mut self) -> RenderTargetBufferMut<'_> {
        match self {
            Self::BorrowedBgra(bgra) => RenderTargetBufferMut::Bgra(bgra),
            Self::BorrowedMono(mono) => RenderTargetBufferMut::Mono(mono),
            Self::OwnedMono(mono) => RenderTargetBufferMut::Mono(mono),
        }
    }

    /// Borrows the pixels as a slice of `P`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer does not store pixels of type `P`.
    fn unwrap_for<P: DrawPixel>(&mut self) -> &'_ mut [P] {
        // TODO: NLL problem case 3
        // The format is read up front because `self` stays mutably borrowed for the
        // whole cast below, including the failure path.
        let actual_format = self.pixel_format();
        match P::cast_target_buffer(self.as_mut()) {
            Some(pixels) => pixels,
            None => panic!(
                "Incompatible render target format {:?} passed to software rasterizer (expected: {:?})",
                actual_format,
                P::PIXEL_FORMAT
            ),
        }
    }
}

/// Shared pixel storage of a software texture.
#[derive(Clone)]
pub(super) enum TextureData {
    /// Monochrome pixels, one `u8` per pixel.
    OwnedMono(Arc<[u8]>),
    /// BGRA pixels.
    OwnedBgra(Arc<[BGRA8]>),
}

/// A texture of the software rasterizer: pixel data plus its dimensions.
#[derive(Clone)]
pub(super) struct TextureImpl {
    /// Width in pixels; the blit code in this file also uses it as the row stride.
    pub width: u32,
    /// Height in pixels.
    pub height: u32,
    /// The texture's pixels.
    pub data: TextureData,
}

impl TextureImpl {
    /// Size of the pixel data: one unit per monochrome pixel, four per BGRA pixel.
    pub(super) fn memory_footprint(&self) -> usize {
        match self.data {
            TextureData::OwnedBgra(ref bgra) => bgra.len() * 4,
            TextureData::OwnedMono(ref mono) => mono.len(),
        }
    }

    /// Whether this texture stores monochrome pixels.
    pub(super) fn is_mono(&self) -> bool {
        matches!(self.data, TextureData::OwnedMono(_))
    }
}

/// Borrowed view of a software texture's pixels, by pixel format.
enum UnwrappedTextureData<'a> {
    /// Monochrome pixels.
    Mono(&'a [u8]),
    /// BGRA pixels.
    Bgra(&'a [BGRA8]),
}

/// Borrowed view of a software texture, as produced by `unwrap_sw_texture`.
struct UnwrappedTexture<'a> {
    /// Width copied from the texture.
    width: u32,
    /// Height copied from the texture.
    height: u32,
    /// The texture's pixels.
    data: UnwrappedTextureData<'a>,
}

/// Borrows the software texture stored inside a backend-agnostic [`super::Texture`].
///
/// # Panics
///
/// Panics if `texture` belongs to a different rasterizer backend.
fn unwrap_sw_texture(texture: &super::Texture) -> UnwrappedTexture<'_> {
    #[cfg_attr(not(feature = "wgpu"), expect(unreachable_patterns))]
    match &texture.0 {
        super::TextureInner::Software(software) => {
            let data = match &software.data {
                TextureData::OwnedBgra(bgra) => UnwrappedTextureData::Bgra(bgra),
                TextureData::OwnedMono(mono) => UnwrappedTextureData::Mono(mono),
            };

            UnwrappedTexture {
                width: software.width,
                height: software.height,
                data,
            }
        }
        other => {
            let variant = other.variant_name();
            panic!(
                "Incompatible texture {:?} passed to software rasterizer",
                variant
            )
        }
    }
}

/// The software rasterizer backend.
pub struct Rasterizer {
    /// Blur state used by the `blur_*` methods.
    blurer: blur::Blurer,
}

impl Rasterizer {
    /// Creates a software rasterizer with a freshly constructed blurer.
    pub fn new() -> Self {
        let blurer = blur::Blurer::new();
        Self { blurer }
    }
}

impl super::Rasterizer for Rasterizer {
    /// Returns the name of this rasterizer backend, `"software"`.
    fn name(&self) -> &'static str {
        "software"
    }

    /// Allocates a `width` × `height` texture of the given `format` and lets
    /// `callback` fill in its bytes.
    ///
    /// `callback` receives the uninitialized byte buffer and the number of bytes per
    /// row (`width` for mono, `width * 4` for BGRA).
    ///
    /// # Safety
    ///
    /// The whole buffer is assumed to be initialized once `callback` returns, so
    /// `callback` must have written every byte of it.
    unsafe fn create_texture_mapped(
        &mut self,
        width: u32,
        height: u32,
        format: PixelFormat,
        callback: Box<dyn FnOnce(&mut [MaybeUninit<u8>], usize) + '_>,
    ) -> super::Texture {
        let n_pixels = width as usize * height as usize;

        let data = match format {
            PixelFormat::Mono => {
                let mut pixels = UniqueArc::new_uninit_slice(n_pixels);

                callback(&mut pixels, width as usize);
                let init = UniqueArc::assume_init(pixels);

                TextureData::OwnedMono(UniqueArc::into_shared(init))
            }
            PixelFormat::Bgra => {
                let mut pixels = UniqueArc::<[BGRA8]>::new_uninit_slice(n_pixels);
                // Expose the BGRA pixels to the callback as raw bytes, four per pixel.
                let bytes = unsafe {
                    std::slice::from_raw_parts_mut(
                        pixels.as_mut_ptr().cast::<MaybeUninit<u8>>(),
                        pixels.len() * 4,
                    )
                };

                callback(bytes, width as usize * 4);
                let init = UniqueArc::assume_init(pixels);

                TextureData::OwnedBgra(UniqueArc::into_shared(init))
            }
        };

        super::Texture(super::TextureInner::Software(TextureImpl {
            width,
            height,
            data,
        }))
    }

    /// Creates a render target that owns a zero-filled `width` × `height`
    /// monochrome buffer with `stride == width`.
    fn create_mono_texture_rendered(
        &mut self,
        width: u32,
        height: u32,
    ) -> super::RenderTarget<'static> {
        let mut pixels = UniqueArc::new_uninit_slice(width as usize * height as usize);
        let buffer = unsafe {
            // Every element is written with zeroes before the slice is assumed
            // to be initialized.
            pixels.fill(MaybeUninit::zeroed());
            RenderTargetBuffer::OwnedMono(UniqueArc::assume_init(pixels))
        };

        super::RenderTarget(super::RenderTargetInner::Software(RenderTargetImpl {
            buffer,
            width,
            height,
            stride: width,
        }))
    }

    /// Turns a render target that owns its monochrome buffer into a texture.
    ///
    /// # Panics
    ///
    /// Panics if `target` is not a software render target, if its stride differs
    /// from its width, or if its buffer is borrowed rather than owned.
    fn finalize_texture_render(&mut self, target: super::RenderTarget<'static>) -> super::Texture {
        #[cfg_attr(not(feature = "wgpu"), expect(unreachable_patterns))]
        match target.0 {
            super::RenderTargetInner::Software(RenderTargetImpl {
                buffer,
                width,
                height,
                stride,
            }) => {
                assert_eq!(stride, width);

                let RenderTargetBuffer::OwnedMono(mono) = buffer else {
                    panic!(
                        "Borrowed render target passed to software Rasterizer::finalize_texture_render"
                    )
                };

                super::Texture(super::TextureInner::Software(TextureImpl {
                    width,
                    height,
                    data: TextureData::OwnedMono(UniqueArc::into_shared(mono)),
                }))
            }
            other => {
                let variant = other.variant_name();
                panic!(
                    "Incompatible target {:?} passed to software Rasterizer::finalize_texture_render (expected: software)",
                    variant
                )
            }
        }
    }

    /// Draws a one pixel high horizontal line from `x0` to `x1` on row `y` of a
    /// BGRA target. Coordinates are truncated to integers; rows outside the target
    /// are ignored and the span is clipped to the target's width.
    fn horizontal_line(
        &mut self,
        target: &mut super::RenderTarget,
        y: f32,
        x0: f32,
        x1: f32,
        color: BGRA8,
    ) {
        let target = unwrap_sw_render_target(target);
        let row = y as i32;

        let row_is_visible = row >= 0 && row < target.height as i32;
        if !row_is_visible {
            return;
        }

        let (start, end) = (x0 as i32, x1 as i32);
        let pixels = target.buffer.unwrap_for::<BGRA8>();
        let row_start = row as usize * target.stride as usize;

        unsafe {
            horizontal_line_unchecked(
                start,
                end,
                &mut pixels[row_start..],
                target.width as i32,
                color,
            )
        }
    }

    fn fill_axis_aligned_rect(
        &mut self,
        target: &mut super::RenderTarget,
        rect: Rect2f,
        color: BGRA8,
    ) {
        if rect.is_empty() {
            return;
        }

        let target = unwrap_sw_render_target(target);

        fill_axis_aligned_antialias_rect(
            rect.min.x,
            rect.min.y,
            rect.max.x,
            rect.max.y,
            target.buffer.unwrap_for::<BGRA8>(),
            target.width,
            target.height,
            target.stride,
            color,
        );
    }

    fn blit(
        &mut self,
        target: &mut super::RenderTarget,
        dx: i32,
        dy: i32,
        texture: &super::Texture,
        color: BGRA8,
    ) {
        let target = unwrap_sw_render_target(target);
        let texture = unwrap_sw_texture(texture);

        match texture.data {
            UnwrappedTextureData::Mono(source) => {
                blit::blit_mono(
                    target.buffer.unwrap_for::<BGRA8>(),
                    target.stride as usize,
                    target.width as usize,
                    target.height as usize,
                    source,
                    texture.width as usize,
                    texture.width as usize,
                    texture.height as usize,
                    dx as isize,
                    dy as isize,
                    color,
                );
            }
            UnwrappedTextureData::Bgra(source) => {
                blit::blit_bgra(
                    target.buffer.unwrap_for::<BGRA8>(),
                    target.stride as usize,
                    target.width as usize,
                    target.height as usize,
                    source,
                    texture.width as usize,
                    texture.width as usize,
                    texture.height as usize,
                    dx as isize,
                    dy as isize,
                    color.a,
                );
            }
        }
    }

    fn blit_to_mono_texture(
        &mut self,
        target: &mut super::RenderTarget,
        dx: i32,
        dy: i32,
        texture: &super::Texture,
    ) {
        let target = unwrap_sw_render_target(target);
        let texture = unwrap_sw_texture(texture);

        match texture.data {
            UnwrappedTextureData::Mono(source) => blit::blit_mono_to_mono(
                target.buffer.unwrap_for::<u8>(),
                target.stride as usize,
                target.width as usize,
                target.height as usize,
                source,
                texture.width as usize,
                texture.width as usize,
                texture.height as usize,
                dx as isize,
                dy as isize,
            ),
            UnwrappedTextureData::Bgra(source) => blit::blit_bgra_to_mono(
                target.buffer.unwrap_for::<u8>(),
                target.stride as usize,
                target.width as usize,
                target.height as usize,
                source,
                texture.width as usize,
                texture.width as usize,
                texture.height as usize,
                dx as isize,
                dy as isize,
            ),
        }
    }

    /// Prepares the blurer for a `width` × `height` input, using the box radius
    /// derived from the Gaussian `sigma`.
    fn blur_prepare(&mut self, width: u32, height: u32, sigma: f32) {
        let (width, height) = (width as usize, height as usize);
        let box_radius = gaussian_sigma_to_box_radius(sigma);

        self.blurer.prepare(width, height, box_radius);
    }

    fn blur_buffer_blit(&mut self, dx: i32, dy: i32, texture: &super::Texture) {
        let texture = unwrap_sw_texture(texture);

        let dx = dx + self.blurer.padding() as i32;
        let dy = dy + self.blurer.padding() as i32;
        let width = self.blurer.width();
        let height = self.blurer.height();
        match texture.data {
            UnwrappedTextureData::Mono(source) => blit::copy_mono_to_float(
                self.blurer.front_mut(),
                width,
                width,
                height,
                source,
                texture.width as usize,
                texture.width as usize,
                texture.height as usize,
                dx as isize,
                dy as isize,
            ),
            UnwrappedTextureData::Bgra(source) => blit::copy_bgra_to_float(
                self.blurer.front_mut(),
                width,
                width,
                height,
                source,
                texture.width as usize,
                texture.width as usize,
                texture.height as usize,
                dx as isize,
                dy as isize,
            ),
        }
    }

    /// Returns the blurer's padding, identical on both axes.
    fn blur_padding(&mut self) -> Vec2<u32> {
        let padding = self.blurer.padding() as u32;
        Vec2::splat(padding)
    }

    // PERF: Evaluate whether storing an f32 texture would be better
    //       or maybe make the last box_blur_vertical blur directly
    //       into a u8 buffer to avoid the floats and copy entirely
    fn blur_to_mono_texture(&mut self) -> super::Texture {
        self.blurer.box_blur_horizontal();
        self.blurer.box_blur_horizontal();
        self.blurer.box_blur_horizontal();
        self.blurer.box_blur_vertical();
        self.blurer.box_blur_vertical();
        self.blurer.box_blur_vertical();

        let mut target = self
            .create_mono_texture_rendered(self.blurer.width() as u32, self.blurer.height() as u32);

        {
            let target = unwrap_sw_render_target(&mut target);

            blit::copy_float_to_mono(
                target.buffer.unwrap_for::<u8>(),
                target.stride as usize,
                target.width as usize,
                target.height as usize,
                self.blurer.front(),
                self.blurer.width(),
                self.blurer.width(),
                self.blurer.height(),
                0,
                0,
            );
        }

        self.finalize_texture_render(target)
    }
}
