// SPDX-FileCopyrightText: 2025 The wlroots contributors
// SPDX-FileCopyrightText: 2025 KylinSoft Co., Ltd.
//
// SPDX-License-Identifier: Expat

#include <assert.h>
#include <drm_fourcc.h>
#include <stdlib.h>
#include <xf86drm.h>

#include <wlr/render/allocator.h>
#include <wlr/render/pixman.h>
#include <wlr/render/swapchain.h>

#include <kywc/log.h>

#include "drm_p.h"
#include "render/opengl.h"
#include "render/renderer.h"
#include "shaders/tex_frag_str.h"
#include "shaders/tex_vert_str.h"

/*
 * Merge the fourcc codes of `set` into renderer->formats, each with only the
 * LINEAR modifier, and always include ARGB8888/LINEAR.
 *
 * Returns false when either argument is NULL, true otherwise. The results of
 * the individual wlr_drm_format_set_add() calls are not checked.
 */
static bool set_renderer_drm_formats(struct drm_renderer *renderer,
                                     const struct wlr_drm_format_set *set)
{
    if (renderer == NULL || set == NULL) {
        return false;
    }

    struct wlr_drm_format_set *out = &renderer->formats;
    for (size_t n = 0; n < set->len; ++n) {
        uint32_t fourcc = set->formats[n].format;
        wlr_drm_format_set_add(out, fourcc, DRM_FORMAT_MOD_LINEAR);
    }
    wlr_drm_format_set_add(out, DRM_FORMAT_ARGB8888, DRM_FORMAT_MOD_LINEAR);

    return true;
}

/*
 * Fallback multi-GPU renderer: a pixman renderer plus an allocator
 * auto-created for drm->wlr_backend.
 *
 * On success renderer->formats receives every format of every plane of `drm`
 * with the LINEAR modifier (plus ARGB8888/LINEAR).
 *
 * Returns false if the renderer or the allocator cannot be created. In that
 * case renderer->wlr_rend is NULL and nothing created here is leaked.
 */
static bool drm_mgpu_renderer_with_pixman(struct drm_device *drm, struct drm_renderer *renderer)
{
    renderer->wlr_rend = wlr_pixman_renderer_create();
    // Check the created renderer, not the (always non-NULL) container struct.
    if (!renderer->wlr_rend) {
        kywc_log(KYWC_ERROR, "Failed to create a pixman renderer for mGPU");
        return false;
    }

    renderer->allocator = wlr_allocator_autocreate(drm->wlr_backend, renderer->wlr_rend);
    if (!renderer->allocator) {
        kywc_log(KYWC_ERROR, "Failed to create allocator for renderer");
        // Do not leak the renderer created above.
        wlr_renderer_destroy(renderer->wlr_rend);
        renderer->wlr_rend = NULL;
        return false;
    }

    for (size_t i = 0; i < drm->num_planes; i++) {
        struct drm_plane *drm_plane = &drm->planes[i];
        set_renderer_drm_formats(renderer, &drm_plane->formats);
    }

    return true;
}

/*
 * Initialise the renderer used to blit into buffers for a secondary DRM device.
 *
 * An OpenGL renderer on drm->fd is tried first. Its DMA-BUF texture formats,
 * excluding DRM_FORMAT_MOD_INVALID, become renderer->formats. If the OpenGL
 * renderer, its texture formats or its allocator are unavailable, the
 * pixman fallback (drm_mgpu_renderer_with_pixman) is used instead.
 *
 * Returns true on success, false if even the fallback fails.
 */
bool drm_mgpu_renderer_init(struct drm_device *drm, struct drm_renderer *renderer)
{
    kywc_log(KYWC_INFO, "DRM mgpu render init");

    renderer->wlr_rend = ky_opengl_renderer_create_with_drm_fd(drm->fd);
    if (!renderer->wlr_rend) {
        return drm_mgpu_renderer_with_pixman(drm, renderer);
    }

    const struct wlr_drm_format_set *texture_formats =
        wlr_renderer_get_texture_formats(renderer->wlr_rend, WLR_BUFFER_CAP_DMABUF);
    if (!texture_formats) {
        kywc_log(KYWC_WARN, "Failed to query renderer texture formats");
        wlr_renderer_destroy(renderer->wlr_rend);
        renderer->wlr_rend = NULL;
        return drm_mgpu_renderer_with_pixman(drm, renderer);
    }

    // Without an allocator no swapchain can be created later; fall back to
    // pixman rather than reporting success with a NULL allocator.
    renderer->allocator = wlr_allocator_autocreate(drm->wlr_backend, renderer->wlr_rend);
    if (!renderer->allocator) {
        kywc_log(KYWC_WARN, "Failed to create allocator for OpenGL renderer");
        wlr_renderer_destroy(renderer->wlr_rend);
        renderer->wlr_rend = NULL;
        return drm_mgpu_renderer_with_pixman(drm, renderer);
    }

    for (size_t i = 0; i < texture_formats->len; i++) {
        const struct wlr_drm_format *fmt = &texture_formats->formats[i];
        for (size_t j = 0; j < fmt->len; j++) {
            uint64_t mod = fmt->modifiers[j];
            // Only explicit modifiers are advertised for mGPU buffers.
            if (mod == DRM_FORMAT_MOD_INVALID) {
                continue;
            }
            wlr_drm_format_set_add(&renderer->formats, fmt->format, mod);
        }
    }

    return true;
}

/* Destroy a swapchain wrapper together with its wlroots swapchain. NULL is ignored. */
static void surface_swapchain_destroy(struct swapchain *swapchain)
{
    if (swapchain != NULL) {
        wlr_swapchain_destroy(swapchain->wlr_swapchain);
        free(swapchain);
    }
}

void drm_mgpu_renderer_finish(struct drm_renderer *renderer)
{
    if (!renderer) {
        return;
    }

    wlr_allocator_destroy(renderer->allocator);
    wlr_renderer_destroy(renderer->wlr_rend);
    wlr_drm_format_set_finish(&renderer->formats);
}

void drm_surface_finish(struct drm_surface *surf)
{
    if (!surf || !surf->wlr_rend) {
        return;
    }

    if (surf->dumb_swapchain != surf->swapchain) {
        surface_swapchain_destroy(surf->swapchain);
    }

    surface_swapchain_destroy(surf->dumb_swapchain);

    *surf = (struct drm_surface){ 0 };
}

/*
 * Replace *dst with the modifiers present in both `a` and `b`, kept in a's
 * order. Both formats must share the same fourcc. dst's previous modifier
 * list is released with wlr_drm_format_finish().
 *
 * Returns false, leaving dst untouched, if the result array cannot be allocated.
 */
static bool drm_format_intersect(struct wlr_drm_format *dst, const struct wlr_drm_format *a,
                                 const struct wlr_drm_format *b)
{
    assert(a->format == b->format);

    // The intersection can never be longer than the shorter input.
    size_t cap = (b->len < a->len) ? b->len : a->len;
    uint64_t *common = malloc(cap * sizeof(*common));
    if (common == NULL) {
        return false;
    }

    size_t count = 0;
    for (size_t i = 0; i < a->len; i++) {
        uint64_t mod = a->modifiers[i];
        bool in_b = false;
        for (size_t j = 0; j < b->len && !in_b; j++) {
            in_b = (b->modifiers[j] == mod);
        }
        if (in_b) {
            assert(count < cap);
            common[count++] = mod;
        }
    }

    wlr_drm_format_finish(dst);
    *dst = (struct wlr_drm_format){
        .format = a->format,
        .len = count,
        .capacity = cap,
        .modifiers = common,
    };

    return true;
}

/*
 * Lazily build the surface_tex shader program from tex_vert_str/tex_frag_str.
 * The attribute and uniform locations are cached in
 * gl_renderer->shaders.surface_tex.
 *
 * Returns true if the program already existed or was just created, and false
 * if ky_opengl_create_program() failed.
 */
static bool create_opengl_shader(struct ky_opengl_renderer *gl_renderer)
{
    struct ky_opengl_surface_shader *shader = &gl_renderer->shaders.surface_tex;
    if (shader->program != 0) {
        return true;
    }

    GLuint program = ky_opengl_create_program(gl_renderer, tex_vert_str, tex_frag_str);
    if (!program) {
        return false;
    }

    shader->program = program;
    shader->in_position = glGetAttribLocation(program, "in_position");
    shader->in_texcoord = glGetAttribLocation(program, "in_texcoord");
    shader->contrast = glGetUniformLocation(program, "contrast");
    shader->whitepoint = glGetUniformLocation(program, "whitepoint");
    shader->color_matrix = glGetUniformLocation(program, "color_matrix");
    shader->brightness = glGetUniformLocation(program, "brightness");

    return true;
}

static bool render_pass_add_texture(struct wlr_renderer *renderer, struct wlr_texture *src_texture,
                                    struct drm_render_target *target)
{
    if (target->damage && !pixman_region32_not_empty(target->damage)) {
        return true;
    }

    struct ky_opengl_renderer *gl_renderer = ky_opengl_renderer_from_wlr_renderer(renderer);
    if (!create_opengl_shader(gl_renderer)) {
        return false;
    }

    struct ky_opengl_surface_shader *gl_shader = &gl_renderer->shaders.surface_tex;
    glUseProgram(gl_shader->program);
    glUniform1i(glGetUniformLocation(gl_shader->program, "tex"), 0);

    int width = src_texture->width;
    int height = src_texture->height;
    pixman_region32_t clipped;
    pixman_region32_init_rect(&clipped, 0, 0, width, height);

    if (target->damage) {
        pixman_region32_intersect_rect(&clipped, target->damage, 0, 0, width, height);
    }

    int rects_len;
    const pixman_box32_t *rects = pixman_region32_rectangles(&clipped, &rects_len);
    float *vertices = calloc(1, rects_len * 6 * 4 * sizeof(float));

    int idx = 0;
    for (int i = 0; i < rects_len; i++) {
        int x1 = rects[i].x1, y1 = rects[i].y1;
        int x2 = rects[i].x2, y2 = rects[i].y2;

        // NDC
        float left = (float)x1 / width * 2.0f - 1.0f;
        float right = (float)x2 / width * 2.0f - 1.0f;

        float top = 1.0 - (float)(height - y2) / height * 2.0f;
        float bottom = 1.0 - (float)(height - y1) / height * 2.0f;

        // UV
        float u1 = x1 / (float)width;

        float u2 = x2 / (float)width;

        float v1 = (y1 / (float)height);
        float v2 = (y2 / (float)height);

        // triangles1
        vertices[idx++] = left, vertices[idx++] = bottom;
        vertices[idx++] = u1, vertices[idx++] = v1;

        vertices[idx++] = right, vertices[idx++] = bottom;
        vertices[idx++] = u2, vertices[idx++] = v1;

        vertices[idx++] = right, vertices[idx++] = top;
        vertices[idx++] = u2, vertices[idx++] = v2;

        // triangles2
        vertices[idx++] = right, vertices[idx++] = top;
        vertices[idx++] = u2, vertices[idx++] = v2;

        vertices[idx++] = left, vertices[idx++] = top;
        vertices[idx++] = u1, vertices[idx++] = v2;

        vertices[idx++] = left, vertices[idx++] = bottom;
        vertices[idx++] = u1, vertices[idx++] = v1;
    }

    pixman_region32_fini(&clipped);

    glEnableVertexAttribArray(gl_shader->in_position);
    glVertexAttribPointer(gl_shader->in_position, 2, GL_FLOAT, GL_FALSE, 4 * sizeof(float),
                          (void *)vertices); // 位置
    glEnableVertexAttribArray(gl_shader->in_texcoord);
    glVertexAttribPointer(gl_shader->in_texcoord, 2, GL_FLOAT, GL_FALSE, 4 * sizeof(float),
                          (void *)(vertices + 2));

    float constrast = 1.0;
    glUniform1f(gl_shader->contrast, constrast);

    float whitepoint[3] = { 0 };
    drm_color_temp_to_rgb(target->color_temp, whitepoint);
    glUniform3fv(gl_shader->whitepoint, 1, whitepoint);

    glUniformMatrix4fv(gl_shader->color_matrix, 1, GL_FALSE, target->color_mat);
    glUniform1f(gl_shader->brightness, (float)target->brightness * 0.01);

    // bind texture
    glActiveTexture(GL_TEXTURE0);

    struct ky_opengl_texture *texture = ky_opengl_texture_from_wlr_texture(src_texture);
    glBindTexture(GL_TEXTURE_2D, texture->tex);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
    // draw damage
    glDrawArrays(GL_TRIANGLES, 0, rects_len * 6);

    glDisableVertexAttribArray(gl_shader->in_position);
    glDisableVertexAttribArray(gl_shader->in_texcoord);
    glBindTexture(GL_TEXTURE_2D, 0);
    glBindBuffer(GL_ARRAY_BUFFER, 0);

    free(vertices);

    return true;
}

static bool readpixel_from_buffer(struct wlr_renderer *wlr_rend, struct wlr_buffer *source,
                                  struct wlr_buffer *dst, const pixman_region32_t *damage)
{

    if (damage && !pixman_region32_not_empty(damage)) {
        return true;
    }

    size_t stride;
    uint32_t format = 0;
    void *dst_ptr = NULL;
    if (!wlr_buffer_begin_data_ptr_access(dst, WLR_BUFFER_DATA_PTR_ACCESS_WRITE, &dst_ptr, &format,
                                          &stride)) {
        kywc_log(KYWC_ERROR, "Failed to get buffer data %p", dst);
        return false;
    }

    struct wlr_dmabuf_attributes attribs;
    if (!wlr_buffer_get_dmabuf(source, &attribs)) {
        wlr_buffer_end_data_ptr_access(dst);
        return false;
    }

    struct wlr_texture *texture = wlr_texture_from_buffer(wlr_rend, source);
    if (!texture) {
        wlr_buffer_end_data_ptr_access(dst);
        return false;
    }

    int x = 0, y = 0;
    int width = source->width;
    int height = source->height;

    if (damage) {
        pixman_region32_t clipped;
        pixman_region32_init(&clipped);
        pixman_region32_intersect_rect(&clipped, damage, 0, 0, width, height);
        y = clipped.extents.y1;
        height = clipped.extents.y2 - clipped.extents.y1;
        pixman_region32_fini(&clipped);
    }

    struct wlr_box box = { x, y, width, height };
    wlr_texture_read_pixels(
        texture, &(struct wlr_texture_read_pixels_options){
                     .data = dst_ptr, .format = format, .stride = stride, .src_box = box });

    wlr_texture_destroy(texture);
    wlr_buffer_end_data_ptr_access(dst);

    return true;
}

/*
 * Draw target->source into a buffer acquired from surf->swapchain, using
 * target->wlr_rend.
 *
 * If target->rgb_clear is set, the area (clipped to target->damage) is first
 * filled with a zero-initialised color without blending.
 *
 * Returns the locked destination buffer, which the caller must unlock, or
 * NULL on failure (each failure is logged).
 */
static struct wlr_buffer *blit_buffer_with_opengl(struct drm_surface *surf,
                                                  struct drm_render_target *target)
{
    struct wlr_renderer *renderer = target->wlr_rend;
    struct wlr_texture *tex = wlr_texture_from_buffer(renderer, target->source);
    if (!tex) {
        kywc_log(KYWC_ERROR, "Failed to import source buffer into multi-GPU renderer");
        return NULL;
    }

    struct wlr_buffer *dst = wlr_swapchain_acquire(surf->swapchain->wlr_swapchain, NULL);
    if (!dst) {
        kywc_log(KYWC_ERROR, "Failed to acquire multi-GPU swapchain buffer");
        goto error_tex;
    }

    struct wlr_render_pass *pass = wlr_renderer_begin_buffer_pass(renderer, dst, NULL);
    if (!pass) {
        kywc_log(KYWC_ERROR, "Failed to begin render pass with multi-GPU destination buffer");
        goto error_dst;
    }

    if (target->rgb_clear) {
        struct wlr_render_rect_options option = (struct wlr_render_rect_options){
            .box = { .width = dst->width, .height = dst->height },
            .clip = target->damage,
            .blend_mode = WLR_RENDER_BLEND_MODE_NONE,
        };
        wlr_render_pass_add_rect(pass, &option);
    }

    bool drawn = render_pass_add_texture(renderer, tex, target);

    // The pass is submitted even when drawing failed, so it is properly
    // closed; the destination buffer is then discarded below.
    if (!wlr_render_pass_submit(pass)) {
        kywc_log(KYWC_ERROR, "Failed to submit multi-GPU render pass");
        goto error_dst;
    }

    if (!drawn) {
        kywc_log(KYWC_ERROR, "Failed to draw source texture into multi-GPU buffer");
        goto error_dst;
    }

    wlr_texture_destroy(tex);

    return dst;

error_dst:
    wlr_buffer_unlock(dst);
error_tex:
    wlr_texture_destroy(tex);
    return NULL;
}

/*
 * Compute into *format the modifiers for fourcc `fmt` that both `renderer`
 * can render to and `display_formats` lists.
 *
 * Returns true with a non-empty *format on success. Returns false if either
 * side lacks the format, the intersection fails, or the intersection is empty
 * (in which case *format is finished).
 */
static bool drm_plane_pick_render_format(struct wlr_renderer *renderer,
                                         const struct wlr_drm_format_set *display_formats,
                                         struct wlr_drm_format *format, uint32_t fmt)
{
    bool is_pixman = wlr_renderer_is_pixman(renderer);
    const struct wlr_drm_format *render_format =
        ky_renderer_get_render_format(renderer, fmt, is_pixman);
    if (!render_format) {
        kywc_log(KYWC_ERROR, "Failed to get render formats");
        return false;
    }

    const struct wlr_drm_format *display_format = wlr_drm_format_set_get(display_formats, fmt);
    if (!display_format) {
        kywc_log(KYWC_ERROR, "Output plane doesn't support format 0x%" PRIX32, fmt);
        return false;
    }

    if (!drm_format_intersect(format, display_format, render_format)) {
        kywc_log(KYWC_ERROR,
                 "Failed to intersect display and render modifiers for format 0x%" PRIX32, fmt);
        return false;
    }

    if (format->len > 0) {
        return true;
    }

    wlr_drm_format_finish(format);
    kywc_log(KYWC_ERROR, "Failed to pick output format");
    return false;
}

static bool test_swapchain(struct wlr_output *output, struct wlr_swapchain *swapchain,
                           const struct wlr_output_state *state)
{
    struct wlr_buffer *buffer = wlr_swapchain_acquire(swapchain, NULL);
    if (buffer == NULL) {
        return false;
    }

    struct wlr_output_state copy = *state;
    copy.committed |= WLR_OUTPUT_STATE_BUFFER;
    copy.buffer = buffer;
    bool ok = wlr_output_test_state(output, &copy);
    wlr_buffer_unlock(buffer);
    return ok;
}

/* Return true if `modifier` is one of fmt's first fmt->len modifiers. */
static bool drm_format_has(const struct wlr_drm_format *fmt, uint64_t modifier)
{
    size_t pos = 0;
    while (pos < fmt->len && fmt->modifiers[pos] != modifier) {
        pos++;
    }
    return pos < fmt->len;
}

/*
 * Append `modifier` to fmt unless it is already present.
 *
 * When the array is full, its capacity doubles (starting at 4). Returns false
 * only if that growth fails, in which case fmt is left unchanged.
 */
static bool drm_format_add(struct wlr_drm_format *fmt, uint64_t modifier)
{
    if (drm_format_has(fmt, modifier)) {
        return true;
    }

    if (fmt->len == fmt->capacity) {
        size_t grown = (fmt->capacity == 0) ? 4 : 2 * fmt->capacity;
        uint64_t *mods = realloc(fmt->modifiers, grown * sizeof(*mods));
        if (mods == NULL) {
            kywc_log(KYWC_ERROR, "Allocation failed");
            return false;
        }
        fmt->modifiers = mods;
        fmt->capacity = grown;
    }

    fmt->modifiers[fmt->len] = modifier;
    fmt->len++;
    return true;
}

static struct swapchain *surface_swapchain_create(struct wlr_allocator *allocator,
                                                  struct swapchain_option *option)
{
    struct swapchain *swapchain = calloc(1, sizeof(*swapchain));
    if (!swapchain) {
        return NULL;
    }

    struct wlr_drm_format *render_format = option->drm_format;
    struct wlr_output *wlr_output = option->output;

    char *format_name = drmGetFormatName(render_format->format);
    kywc_log(KYWC_DEBUG, "Choosing plane buffer format %s (0x%08" PRIX32 ") for output '%s'",
             format_name ? format_name : "<unknown>", render_format->format, wlr_output->name);
    free(format_name);

    struct wlr_swapchain *wlr_swapchain =
        wlr_swapchain_create(allocator, option->width, option->height, render_format);
    if (!wlr_swapchain) {
        wlr_drm_format_finish(render_format);
        free(swapchain);
        return NULL;
    }

    if (!option->state) {
        swapchain->wlr_swapchain = wlr_swapchain;
        return swapchain;
    }

    if (!test_swapchain(wlr_output, wlr_swapchain, option->state)) {
        kywc_log(KYWC_DEBUG, "Output test failed on '%s', retrying without modifiers",
                 option->output->name);
        wlr_swapchain_destroy(wlr_swapchain);
        wlr_swapchain = NULL;

        if (render_format->len != 1 || render_format->modifiers[0] != DRM_FORMAT_MOD_LINEAR) {
            if (!drm_format_has(render_format, DRM_FORMAT_MOD_INVALID)) {
                kywc_log(KYWC_DEBUG, "Implicit modifiers not supported");
                wlr_drm_format_finish(render_format);
            }
        }

        render_format->len = 0;
        if (!drm_format_add(render_format, DRM_FORMAT_MOD_INVALID)) {
            kywc_log(KYWC_DEBUG, "Failed to add implicit modifier to format");
            goto error;
        }

        wlr_swapchain =
            wlr_swapchain_create(allocator, option->width, option->height, render_format);
        if (!test_swapchain(wlr_output, wlr_swapchain, option->state)) {
            kywc_log(KYWC_DEBUG, "Output test failed on '%s', retrying without modifiers",
                     option->output->name);
            goto error;
        }
    }

    swapchain->wlr_swapchain = wlr_swapchain;
    return swapchain;

error:
    wlr_drm_format_finish(render_format);
    wlr_swapchain_destroy(wlr_swapchain);
    free(swapchain);
    return NULL;
}

/*
 * (Re)configure plane->multi_surf for width x height buffers of DRM fourcc
 * `format` on connector `conn`.
 *
 * Returns true immediately if the current swapchain already has the requested
 * width and height (the format is not compared). Otherwise the existing
 * swapchains are destroyed and the surface is rebuilt as follows:
 *  - no mgpu_renderer and a pixman output renderer: no swapchain is created.
 *    surf->wlr_rend is set to the output renderer, so drm_surface_blit() can
 *    pass buffers through.
 *  - mgpu_renderer that is not pixman: a swapchain is created with the mGPU
 *    renderer and allocator, restricted to mgpu_renderer->formats.
 *  - otherwise: a swapchain is created with the output renderer/allocator,
 *    restricted to mgpu_renderer->formats when present, else plane->formats.
 *  - additionally, when mgpu_renderer is pixman, a "dumb" swapchain is created
 *    with the mGPU allocator. It uses `format` if the plane lists it, else
 *    XRGB8888. If the output renderer is pixman too, no GPU swapchain is
 *    created and the dumb swapchain is also used as surf->swapchain.
 *
 * Returns false if no usable format/swapchain could be set up. Then
 * surf->swapchain and surf->dumb_swapchain are NULL and surf->wlr_rend keeps
 * its previous value.
 */
bool drm_plane_configure_surface_swapchain(struct drm_plane *plane, struct drm_connector *conn,
                                           struct drm_renderer *mgpu_renderer, int width,
                                           int height, uint32_t format,
                                           const struct wlr_output_state *state)
{
    struct drm_surface *surf = &plane->multi_surf;
    // Only the size is compared; a format change alone does not rebuild.
    if (surf->swapchain && surf->swapchain->wlr_swapchain &&
        surf->swapchain->wlr_swapchain->width == width &&
        surf->swapchain->wlr_swapchain->height == height) {
        return true;
    }

    // swapchain and dumb_swapchain may be the same wrapper; destroy it once.
    if (surf->dumb_swapchain != surf->swapchain) {
        surface_swapchain_destroy(surf->swapchain);
    }

    surface_swapchain_destroy(surf->dumb_swapchain);
    surf->swapchain = NULL;
    surf->dumb_swapchain = NULL;

    /* both primary and secondary gpu use pixman */
    bool both_pixman = false;
    struct wlr_drm_format render_format = { 0 };
    struct wlr_renderer *renderer = conn->output.renderer;
    struct wlr_allocator *allocator = conn->output.allocator;
    const struct wlr_drm_format_set *display_formats = &plane->formats;
    struct swapchain *swapchain = NULL;
    struct swapchain *dumb_swapchain = NULL;

    if (mgpu_renderer != NULL) {
        display_formats = &mgpu_renderer->formats;
        if (!wlr_renderer_is_pixman(mgpu_renderer->wlr_rend)) {
            renderer = mgpu_renderer->wlr_rend;
            allocator = mgpu_renderer->allocator;
        } else if (wlr_renderer_is_pixman(renderer)) {
            both_pixman = true;
            // Skip the GPU swapchain; only the dumb swapchain is created.
            goto dumb_alloc;
        }
    } else if (wlr_renderer_is_pixman(renderer)) {
        // Single pixman GPU: no swapchain needed.
        surf->wlr_rend = renderer;
        return true;
    }

    if (!drm_plane_pick_render_format(renderer, display_formats, &render_format, format)) {
        kywc_log(KYWC_WARN, "Renderer pick renderer format failed");
        return false;
    }

    struct swapchain_option option = {
        .width = width,
        .height = height,
        .output = &conn->output,
        .drm_format = &render_format,
        .state = state,
    };

    swapchain = surface_swapchain_create(allocator, &option);
    if (!swapchain) {
        return false;
    }
    // surface_swapchain_create() does not release the format on success.
    wlr_drm_format_finish(&render_format);

dumb_alloc:
    // A pixman mGPU renderer gets an extra "dumb" swapchain from its own
    // allocator; drm_surface_blit() reads the blitted pixels back into it.
    if (mgpu_renderer && wlr_renderer_is_pixman(mgpu_renderer->wlr_rend)) {
        struct wlr_drm_format drm_format = { 0 };
        uint32_t disp_format = format;
        // Fall back to XRGB8888 if the plane does not list the requested format.
        if (!wlr_drm_format_set_get(&plane->formats, format)) {
            disp_format = DRM_FORMAT_XRGB8888;
        }
        if (!drm_plane_pick_render_format(mgpu_renderer->wlr_rend, &plane->formats, &drm_format,
                                          disp_format)) {
            kywc_log(KYWC_WARN, "surface dumb format format failed");
            surface_swapchain_destroy(swapchain);
            return false;
        }

        // No output state is passed, so the dumb swapchain is not test-committed.
        struct swapchain_option option = {
            .width = width,
            .height = height,
            .output = &conn->output,
            .drm_format = &drm_format,
            .state = NULL,
        };
        dumb_swapchain = surface_swapchain_create(mgpu_renderer->allocator, &option);
        if (!dumb_swapchain) {
            surface_swapchain_destroy(swapchain);
            return false;
        }
        wlr_drm_format_finish(&drm_format);

        if (both_pixman) {
            swapchain = dumb_swapchain;
        }
    }

    surf->swapchain = swapchain;
    surf->dumb_swapchain = dumb_swapchain;
    surf->wlr_rend = renderer;

    return true;
}

/*
 * Produce the buffer to scan out for `target` on `surface`.
 *
 *  - Pixman surface renderer without any swapchain: a new reference to
 *    target->source is returned as-is.
 *  - OpenGL surface renderer: target->source is drawn into the surface
 *    swapchain (blit_buffer_with_opengl).
 *  - Otherwise target->source itself is used.
 *
 * If the surface has a dumb swapchain, the result is then read back into a
 * buffer acquired from it, and that buffer is returned instead.
 *
 * Returns a locked buffer the caller must unlock, or NULL on failure.
 */
struct wlr_buffer *drm_surface_blit(struct drm_surface *surface, struct drm_render_target *target)
{
    if (wlr_renderer_is_pixman(surface->wlr_rend) && !surface->swapchain &&
        !surface->dumb_swapchain) {
        // do not blit
        return wlr_buffer_lock(target->source);
    }

    struct wlr_buffer *buffer = NULL;
    if (wlr_renderer_is_opengl(surface->wlr_rend)) {
        buffer = blit_buffer_with_opengl(surface, target);
        // Never hand a NULL buffer to the read-back below.
        if (!buffer) {
            return NULL;
        }
    } else {
        buffer = wlr_buffer_lock(target->source);
    }

    if (!surface->dumb_swapchain) {
        return buffer;
    }

    struct wlr_buffer *dumb_buffer =
        wlr_swapchain_acquire(surface->dumb_swapchain->wlr_swapchain, NULL);
    if (!dumb_buffer) {
        kywc_log(KYWC_ERROR, "Failed to acquire multi-GPU dumb swapchain buffer");
        goto error;
    }

    if (!readpixel_from_buffer(target->wlr_rend, buffer, dumb_buffer, target->damage)) {
        goto error_pixel;
    }

    // The intermediate buffer is no longer needed once copied into the dumb buffer.
    wlr_buffer_unlock(buffer);
    return dumb_buffer;

error_pixel:
    wlr_buffer_unlock(dumb_buffer);
error:
    wlr_buffer_unlock(buffer);
    return NULL;
}
