//
//  File.metal
//  PPCameraView
//
//  Created by sanhue cheng on 2020/4/17.
//

#include <metal_stdlib>

using namespace metal;

#import "PPCVKernelTypes.h"

//==============================================================================
//
//==============================================================================
/// Classifies each pixel of a rectangular region of `source` into one of three
/// brightness levels relative to `params->medianValue` and writes the level into
/// the red channel of `result` at the region-relative position.
///
/// Output levels (red channel):
///   0.0 - brightest channel is more than 4 below the median,
///   0.5 - brightest channel is within +/-4 of the median,
///   1.0 - brightest channel is more than 4 above the median.
/// Green and blue are written as 0, alpha as 1.
///
/// @param source  input image; channels are read as normalized floats (0..1).
/// @param result  output map; source pixel (startX, startY) is written at (0, 0).
///                Positions that fall outside `result` are skipped.
/// @param params  region bounds (startX/stopX, startY/stopY; stop is exclusive)
///                and the median brightness on a 0..255 scale.
kernel void createMapping(texture2d<float, access::read> source[[texture(0)]],
                          texture2d<float, access::write> result[[texture(1)]],
                          constant PPCVCreateMappingParams *params[[buffer(0)]],
                          const uint2 threadInGrid[[thread_position_in_grid]])
{
    // Threads outside the source image do nothing.
    if(threadInGrid.x >= source.get_width() || threadInGrid.y >= source.get_height())
    {
        return;
    }

    const uint startX = (uint)params->startX;
    const uint startY = (uint)params->startY;

    // Only pixels inside the requested region are mapped.
    if(threadInGrid.x < startX || threadInGrid.x >= (uint)params->stopX ||
       threadInGrid.y < startY || threadInGrid.y >= (uint)params->stopY)
    {
        return;
    }

    // Region-relative output position; skip it if `result` cannot hold it.
    const uint2 resultGrid = uint2(threadInGrid.x - startX, threadInGrid.y - startY);
    if(resultGrid.x >= result.get_width() || resultGrid.y >= result.get_height())
    {
        return;
    }

    // Channels are normalized (value / 255, in 0..1); use the brightest of R, G, B.
    const float4 sourceColor = source.read(threadInGrid);
    const float maxRGB = max(sourceColor.r, max(sourceColor.g, sourceColor.b));

    // Round back to the 0..255 scale. Plain truncation can turn k/255*255 that
    // lands a hair below k (float error) into k-1, shifting the median band.
    const int maxValue = (int)round(maxRGB * 255.0f);

    // Absolute distance from the median (written so it works whatever integer
    // type medianValue has in PPCVCreateMappingParams).
    int diff = maxValue - params->medianValue;
    if(diff < 0)
    {
        diff = -diff;
    }

    // Exact constants are written so a reader can tell the three levels apart.
    // NOTE(review): if `result` is an 8-bit unorm texture, 0.5 is stored as
    // 128/255 (~0.502); readers should classify by nearest level, not by ==.
    float resultValue;
    if(diff <= 4)
    {
        resultValue = 0.5f;
    }
    else if(maxValue <= params->medianValue)
    {
        resultValue = 0.0f;
    }
    else
    {
        resultValue = 1.0f;
    }

    // Return the result through the output texture.
    result.write(float4(resultValue, 0.0f, 0.0f, 1.0f), resultGrid);
}


//==============================================================================
//
//==============================================================================
/// Maps a mapping-texture red value (written by createMapping as 0.0, 0.5 or
/// 1.0) to a level index 0, 1 or 2. It tolerates 8-bit unorm storage, where
/// 0.5 reads back as 128/255.
static int alignmentLevel(float value)
{
    return (int)round(value * 2.0f);
}

/// Accumulates, for a 3x3 set of candidate shifts around
/// (params->offsetX, params->offsetY), how many sampled pixels disagree between
/// the `base` and `compare` brightness maps.
///
/// Each thread samples base pixel (gx*stepSize, gy*stepSize) and the 9 compare
/// pixels at that position + offset + (dx, dy)*stepSize for dx, dy in {-1, 0, 1}.
/// A disagreement is counted only when both pixels are definite (level 0 or 2)
/// and differ; level 1 (within the median band) is ignored.
///
/// @param resultBuffer  9 counters indexed (dy+1)*3 + (dx+1). They are only
///                      incremented, so presumably the host zeroes them before
///                      dispatch. They are updated atomically because many
///                      threads hit the same counter. Layout is 32-bit ints,
///                      the same as a plain `int` buffer.
kernel void alignmentOffset(texture2d<float, access::read> base[[texture(0)]],
                            texture2d<float, access::read> compare[[texture(1)]],
                            constant PPCVCalculateOffsetParams *params[[buffer(0)]],
                            device atomic_int *resultBuffer[[buffer(1)]],
                            const uint2 threadInGrid[[thread_position_in_grid]])
{
    const int stepSize = params->stepSize;
    const int offX = params->offsetX;
    const int offY = params->offsetY;

    // Sample position in full-resolution coordinates.
    const int x = (int)threadInGrid.x * stepSize;
    const int y = (int)threadInGrid.y * stepSize;

    // Shifted position in `compare`; it needs a one-step border for the 3x3 probe.
    const int cx = x + offX;
    const int cy = y + offY;
    const bool compareInRange = cx >= stepSize && cx < (int)compare.get_width() - stepSize &&
                                cy >= stepSize && cy < (int)compare.get_height() - stepSize;
    const bool baseInRange = x < (int)base.get_width() && y < (int)base.get_height();
    if(!compareInRange || !baseInRange)
    {
        return;
    }

    // Read base at the same (scaled) position used for `compare`.
    const int baseLevel = alignmentLevel(base.read(uint2(x, y)).r);
    if(baseLevel == 1)
    {
        // Uncertain base pixel: it never contributes to any counter.
        return;
    }

    int c = 0;
    for(int dy = -1; dy <= 1; dy++)
    {
        for(int dx = -1; dx <= 1; dx++)
        {
            const int compareLevel =
                alignmentLevel(compare.read(uint2(cx + dx*stepSize, cy + dy*stepSize)).r);

            if(compareLevel != 1 && compareLevel != baseLevel)
            {
                atomic_fetch_add_explicit(&resultBuffer[c], 1, memory_order_relaxed);
            }

            c++;
        }
    }
}


//==============================================================================
//
//==============================================================================
/// Tone-maps an HDR colour given on a 0..255 scale with the common rational
/// approximation of the ACES filmic curve:
///   f(v) = v(2.51v + 0.03) / (v(2.43v + 0.59) + 0.14),  v = hdr / 255.
/// The result is scaled back to 0..255, rounded to nearest and clamped per
/// channel. Alpha is always 255.
static uchar4 ACES(float3 hdr)
{
    // Evaluate the curve on normalized input.
    const float3 v = hdr / 255.0f;
    const float3 numerator   = v * (2.51f * v + 0.03f);
    const float3 denominator = v * (2.43f * v + 0.59f) + 0.14f;
    const float3 mapped = 255.0f * numerator / denominator;

    // +0.5 followed by truncating conversion rounds to nearest; the clamp keeps
    // the conversion to uchar in range.
    const float3 clamped = clamp(mapped + 0.5f, float3(0.0f), float3(255.0f));
    return uchar4(uchar3(clamped), uchar(255));
}


//==============================================================================
//
//==============================================================================
/// Returns true when integer pixel position `p` lies inside `tex`.
static bool hdrInBounds(texture2d<float, access::read> tex, int2 p)
{
    return p.x >= 0 && p.y >= 0 && p.x < (int)tex.get_width() && p.y < (int)tex.get_height();
}

/// Deghosting: blends a neighbour-exposure colour back towards the base colour.
/// With L the squared RGB distance between them, the base colour gets weight
/// L / (L + C). Large differences (likely motion) therefore favour the base
/// image. C is 4000, rising linearly to 10000 as the neighbour's brightest
/// channel moves from `safeRange` to 127.5 away from mid-grey. Neighbour colours
/// whose brightest channel exceeds 250 are returned unchanged.
static float3 hdrDeghost(float3 baseRGB, float3 neighbourRGB, float safeRange)
{
    const float value = fmax(fmax(neighbourRGB.r, neighbourRGB.g), neighbourRGB.b);
    if(value <= 250.0f)
    {
        // An earlier tuning used 2000 / 8000.
        const float wienerLo = 4000.0f;
        const float wienerHi = 10000.0f;
        float wienerC = wienerLo;
        const float excess = fabs(value - 127.5f) - safeRange;
        if(excess > 0.0f)
        {
            const float scale = (wienerHi - wienerLo) / (127.5f - safeRange);
            wienerC = wienerLo + excess * scale;
        }
        const float3 delta = baseRGB - neighbourRGB;
        const float L = dot(delta, delta);
        const float ghostWeight = L / (L + wienerC);
        return ghostWeight * baseRGB + (1.0f - ghostWeight) * neighbourRGB;
    }
    return neighbourRGB;
}

/// Merges three aligned exposures into one ACES tone-mapped image.
///
/// source1 is the base image. source0 and source2 are sampled at the per-image
/// offsets params->offsetX/offsetY[0] and [2]. Each image i is linearised with
/// params->parameterA[i] * rgb + params->parameterB[i] (rgb on 0..255). The base
/// colour has weight 1 while its average is within 80 of 127.5, then falls
/// linearly to 0.01 at 0/255. The remaining weight goes to source2 when the base
/// is dark (avg <= 127.5) and to source0 otherwise. Presumably source0 is the
/// darker and source2 the brighter exposure — TODO confirm with the caller.
/// If an offset sample falls outside its texture, the base pixel and base
/// response parameters are used in its place.
kernel void generateHDR(texture2d<float, access::read> source0[[texture(0)]],
                        texture2d<float, access::read> source1[[texture(1)]],
                        texture2d<float, access::read> source2[[texture(2)]],
                        texture2d<float, access::write> result[[texture(3)]],
                        constant PPCVGenerateHDRParams *params[[buffer(0)]],
                        const uint2 threadInGrid[[thread_position_in_grid]])
{
    // Skip threads outside source0 (the original grid reference), and also any
    // that would read source1 or write result out of bounds.
    if(threadInGrid.x >= source0.get_width() || threadInGrid.y >= source0.get_height() ||
       threadInGrid.x >= source1.get_width() || threadInGrid.y >= source1.get_height() ||
       threadInGrid.x >= result.get_width()  || threadInGrid.y >= result.get_height())
    {
        return;
    }

    const int numImages = 3;
    const int mid = (numImages - 1) / 2;   // index of the base image (source1)

    const int offsetX0 = params->offsetX[0];
    const int offsetY0 = params->offsetY[0];
    const int offsetX2 = params->offsetX[2];
    const int offsetY2 = params->offsetY[2];
    const int2 pos = int2(threadInGrid);
    const int2 pos0 = pos + int2(offsetX0, offsetY0);
    const int2 pos2 = pos + int2(offsetX2, offsetY2);

    const float4 basePixel = source1.read(threadInGrid);

    float4 pixels[numImages];
    float parameterA[numImages];
    float parameterB[numImages];
    for(int i = 0; i < numImages; i++)
    {
        parameterA[i] = params->parameterA[i];
        parameterB[i] = params->parameterB[i];
    }

    // Aligned samples; out-of-range samples fall back to the base pixel/response.
    if(hdrInBounds(source0, pos0))
    {
        pixels[0] = source0.read(uint2(pos0));
    }
    else
    {
        pixels[0] = basePixel;
        parameterA[0] = parameterA[mid];
        parameterB[0] = parameterB[mid];
    }

    pixels[mid] = basePixel;

    if(hdrInBounds(source2, pos2))
    {
        pixels[2] = source2.read(uint2(pos2));
    }
    else
    {
        pixels[2] = basePixel;
        parameterA[2] = parameterA[mid];
        parameterB[2] = parameterB[mid];
    }

    // Weight of the base exposure from its average brightness (0..255 scale).
    const float safeRange = 80.0f;
    const float3 baseRaw = pixels[mid].rgb * 255.0f;
    const float avg = (baseRaw.r + baseRaw.g + baseRaw.b) / 3.0f;
    const float distFromMid = fabs(avg - 127.5f);
    float baseWeight = 1.0f;
    if(distFromMid > safeRange)
    {
        baseWeight = 1.0f - 0.99f * (distFromMid - safeRange) / (127.5f - safeRange);
    }

    const float3 baseRGB = parameterA[mid] * baseRaw + parameterB[mid];
    float3 hdr = baseWeight * baseRGB;
    float sumWeight = baseWeight;

    if(baseWeight < 1.0f)
    {
        // Hand the remaining weight to the neighbour exposure on the other side.
        const float neighbourWeight = 1.0f - baseWeight;
        const int n = (avg <= 127.5f) ? mid + 1 : mid - 1;
        float3 neighbourRGB = parameterA[n] * (pixels[n].rgb * 255.0f) + parameterB[n];
        neighbourRGB = hdrDeghost(baseRGB, neighbourRGB, safeRange);

        hdr += neighbourWeight * neighbourRGB;
        sumWeight += neighbourWeight;
    }

    hdr /= sumWeight;

    // Tone-map to 8 bits, then return the result through the output texture.
    const uchar4 toneMapped = ACES(hdr);
    const float4 outColor = float4(float3(toneMapped.rgb) / 255.0f, 1.0f);
    result.write(outColor, threadInGrid);
}
