import { useRef, useEffect, useCallback } from 'react';
import * as d3 from 'd3';
import CanvasLayer from '../components/CanvasLayer';

const useCanvasDrawing = ({
    staticCanvasRef,
    highlightCanvasRef,
    imageRef,
    isImageLoaded,
    displacementData,
    popups,
    selectionMode,
    polygonPoints,
    currentMousePos,
    isDrawing,
    isFinalizing,
    addPolygonPoint,
    addSinglePopup,
    addMultiPopup,
    fetchTimeseriesData,
    fetchTimeseriesForMultiplePoints,
    setPolygonPoints,
    setCurrentMousePos,
    setError,
    transformState,
}) => {
    // Sequential d3 color scale used to color displacement markers. The
    // interpolator blends through the listed stops in order (dark blue, blue,
    // cyan, green, yellow, red, dark red). The scale is stored on a ref and is
    // read by the drawing callbacks through `colorScale.current`.
    const markerColorStops = [
        '#00007f',
        '#0000ff',
        '#00ffff',
        '#00ff00',
        '#ffff00',
        '#ff0000',
        '#7f0000',
    ];
    const colorScale = useRef(
        d3.scaleSequential(d3.interpolateRgbBasis(markerColorStops))
    );

    // Sets the domain of the marker color scale before markers are drawn.
    //
    // The domain is a fixed, zero-centered range of [-15, 15]; it does not
    // depend on the values in `displacements`. Nothing happens when
    // `displacements` is missing or empty.
    //
    // NOTE(review): the previous body sorted the values and derived a
    // symmetric limit from the 20th/80th percentiles, but that result was
    // never applied to the scale. The unused computation was removed; confirm
    // that the fixed range is the intended behavior.
    const updateColorScale = useCallback((displacements) => {
        if (!displacements || displacements.length === 0) return;

        const domainLimit = 15;
        colorScale.current.domain([-domainLimit, domainLimit]);
    }, []);

    // Redraws the static marker canvas: one 10x10 square outline per
    // displacement point, centered on the point and colored through the color
    // scale (falling back to '#FF0000' when the scale returns a falsy value).
    //
    // Does nothing until the canvas, the image and the loaded flag are all
    // available. A missing or empty `displacements` list leaves the canvas
    // cleared. Points outside the image's natural dimensions are logged and
    // skipped.
    const drawStaticMarkers = useCallback((displacements) => {
        if (!staticCanvasRef.current || !imageRef.current || !isImageLoaded) return;
        const canvas = staticCanvasRef.current;
        const ctx = canvas.getContext('2d');
        const img = imageRef.current;

        // Match the canvas backing size to the image's natural size; point
        // coordinates are bounds-checked against those same dimensions below.
        canvas.width = img.naturalWidth;
        canvas.height = img.naturalHeight;

        ctx.clearRect(0, 0, canvas.width, canvas.height);

        // Treat a missing list like an empty one instead of throwing on `.length`.
        if (!displacements || displacements.length === 0) return;

        // Update the color scale before drawing
        updateColorScale(displacements);

        const markerSize = 10;
        const halfMarker = markerSize / 2;

        ctx.save();
        ctx.globalAlpha = 1;
        // The stroke width is the same for every marker, so set it once.
        ctx.lineWidth = 3;

        displacements.forEach(({ x, y, displacement }) => {
            if (x < 0 || x > img.naturalWidth || y < 0 || y > img.naturalHeight) {
                console.log('Point out of bounds:', x, y);
                return;
            }

            ctx.strokeStyle = colorScale.current(displacement) || '#FF0000';
            ctx.strokeRect(x - halfMarker, y - halfMarker, markerSize, markerSize);
        });

        ctx.restore();
    }, [staticCanvasRef, imageRef, isImageLoaded, updateColorScale]);

    // Repaints the highlight canvas from scratch:
    //   - a white ring (radius 8) at every 'single' popup position,
    //   - a ring in the popup's own color at every point of a 'multi' popup,
    //   - in Polygon mode, the polygon outline; while drawing it ends at the
    //     current mouse position, otherwise it is closed back to the first
    //     vertex.
    // Does nothing until the canvas, the image and the loaded flag are ready.
    const drawHighlights = useCallback(() => {
        const canvas = highlightCanvasRef.current;
        const img = imageRef.current;
        if (!canvas || !img || !isImageLoaded) return;

        const ctx = canvas.getContext('2d');

        canvas.width = img.naturalWidth;
        canvas.height = img.naturalHeight;
        ctx.clearRect(0, 0, canvas.width, canvas.height);

        // Strokes one highlight ring around `center` in the given color.
        const strokeRing = (center, color) => {
            ctx.strokeStyle = color;
            ctx.lineWidth = 3;
            ctx.beginPath();
            ctx.arc(center.x, center.y, 8, 0, 2 * Math.PI);
            ctx.stroke();
        };

        popups.forEach((popup) => {
            switch (popup.type) {
                case 'single':
                    strokeRing(popup, '#FFFFFF');
                    break;
                case 'multi':
                    popup.points.forEach((point) => strokeRing(point, popup.color));
                    break;
                default:
                    break;
            }
        });

        if (selectionMode !== 'Polygon' || polygonPoints.length === 0) return;

        const [firstVertex, ...laterVertices] = polygonPoints;
        // Open edge follows the cursor while drawing; otherwise close the shape.
        const lastTarget = isDrawing.current && currentMousePos ? currentMousePos : firstVertex;

        ctx.strokeStyle = '#FFFFFF';
        ctx.lineWidth = 3;
        ctx.beginPath();
        ctx.moveTo(firstVertex.x, firstVertex.y);
        laterVertices.forEach((vertex) => {
            ctx.lineTo(vertex.x, vertex.y);
        });
        ctx.lineTo(lastTarget.x, lastTarget.y);
        ctx.stroke();
    }, [highlightCanvasRef, imageRef, isImageLoaded, popups, selectionMode, polygonPoints, currentMousePos, isDrawing]);

    // Redraw the static markers whenever the image becomes available or the
    // displacement data changes. An empty (or missing) dataset is forwarded as
    // well: drawStaticMarkers clears the canvas before returning on empty
    // input, so markers from a previous dataset do not linger.
    useEffect(() => {
        if (!isImageLoaded) return;
        drawStaticMarkers(displacementData || []);
    }, [isImageLoaded, displacementData, drawStaticMarkers]);

    // Repaint the highlight overlay once the image is loaded, and again
    // whenever any of the listed dependencies changes.
    useEffect(() => {
        if (!isImageLoaded) return;

        drawHighlights();
    }, [isImageLoaded, popups, drawHighlights, polygonPoints, currentMousePos]);

    // Returns the entries of `displacementData` whose (x, y) position is
    // reported by d3.polygonContains as inside `polygon`. `polygon` is an
    // array of {x, y} vertices; it is converted to the [x, y] pair form that
    // d3 works with.
    const getPointsWithinPolygon = useCallback((polygon) => {
        const vertices = polygon.map(({ x, y }) => [x, y]);
        const isInside = ({ x, y }) => d3.polygonContains(vertices, [x, y]);

        return displacementData.filter(isInside);
    }, [displacementData]);

    // Converts viewport (client) coordinates into the image's natural pixel
    // space, using the image's current bounding rectangle.
    //
    // Returns { x, y } in natural pixels, or the sentinel { x: -1, y: -1 } when
    // no usable position exists:
    //   - the image element is not available,
    //   - the image currently has no size on screen (dividing by a zero
    //     width/height would otherwise yield NaN or Infinity),
    //   - the position lies outside the image's bounding rectangle.
    // Every handler in this hook treats the sentinel as "ignore this event".
    const clientToNatural = useCallback((clientX, clientY) => {
        const outOfBounds = { x: -1, y: -1 };

        const img = imageRef.current;
        if (!img) return outOfBounds;

        const rect = img.getBoundingClientRect();
        if (rect.width <= 0 || rect.height <= 0) return outOfBounds;

        const relativeX = clientX - rect.left;
        const relativeY = clientY - rect.top;

        if (relativeX < 0 || relativeY < 0 || relativeX > rect.width || relativeY > rect.height) {
            return outOfBounds;
        }

        // Ratio between natural and displayed size along each axis.
        const scaleX = img.naturalWidth / rect.width;
        const scaleY = img.naturalHeight / rect.height;

        return { x: relativeX * scaleX, y: relativeY * scaleY };
    }, [imageRef]);

    // Click handler for Point mode. Converts the click to natural image
    // coordinates, picks the displacement point closest to it within a
    // 10-pixel radius (on a tie the later entry wins), then opens a single
    // popup for that point and requests its time series. Clicks outside the
    // image, or with no point in range, do nothing.
    const handleCanvasClick = useCallback((e) => {
        const isActive = selectionMode === 'Point' && highlightCanvasRef.current && imageRef.current;
        if (!isActive) return;

        const click = clientToNatural(e.clientX, e.clientY);
        if (click.x === -1 && click.y === -1) return;

        const hitRadius = 10;
        const nearest = displacementData.reduce(
            (best, candidate) => {
                const dx = candidate.x - click.x;
                const dy = candidate.y - click.y;
                const distance = Math.sqrt(dx * dx + dy * dy);

                return distance <= best.distance ? { point: candidate, distance } : best;
            },
            { point: null, distance: hitRadius }
        );

        if (!nearest.point) return;

        addSinglePopup(nearest.point, []);
        fetchTimeseriesData(nearest.point);
    }, [selectionMode, highlightCanvasRef, imageRef, displacementData, addSinglePopup, fetchTimeseriesData, clientToNatural]);

    // Click handler for Polygon mode: appends the clicked position (in natural
    // image coordinates) as a polygon vertex. The first click of a new polygon
    // switches drawing on and resets the vertex list. Clicks with a click
    // count above one (`e.detail > 1`, i.e. the later clicks of a double
    // click) and clicks outside the image add no vertex.
    const handlePolygonAddPoint = useCallback((e) => {
        const isActive = selectionMode === 'Polygon' && highlightCanvasRef.current && imageRef.current;
        if (!isActive || e.detail > 1) return;

        const startsNewPolygon = !isDrawing.current;
        if (startsNewPolygon) {
            isDrawing.current = true;
            setPolygonPoints([]);
        }

        const vertex = clientToNatural(e.clientX, e.clientY);
        const isOutsideImage = vertex.x === -1 && vertex.y === -1;

        if (!isOutsideImage) {
            addPolygonPoint({ x: vertex.x, y: vertex.y });
        }
    }, [selectionMode, addPolygonPoint, isDrawing, highlightCanvasRef, imageRef, setPolygonPoints, clientToNatural]);

    // Double-click handler for Polygon mode: finishes the polygon being drawn.
    //
    // Drawing is switched off, then:
    //   - fewer than three vertices -> an error is reported (vertices kept),
    //   - no displacement points inside -> an error is reported and the
    //     vertices are cleared,
    //   - otherwise a multi-point popup is added and its time series are
    //     requested under a group id built from the sorted point ids joined
    //     with '-'; the vertices and the tracked mouse position are cleared.
    //
    // `isFinalizing` guards against re-entrant invocations. It is released in
    // a `finally` block so an exception thrown by any of the callbacks cannot
    // leave the flag stuck and block every later polygon from completing.
    const handlePolygonEnd = useCallback((e) => {
        if (selectionMode !== 'Polygon' || !isDrawing.current || !highlightCanvasRef.current || !imageRef.current) return;

        if (isFinalizing.current) {
            return;
        }
        isFinalizing.current = true;

        try {
            isDrawing.current = false;

            if (polygonPoints.length < 3) {
                setError('A polygon requires at least three points.');
                return;
            }

            const selectedPoints = getPointsWithinPolygon(polygonPoints);

            if (selectedPoints.length === 0) {
                setError('No points selected within the polygon.');
                setPolygonPoints([]);
                return;
            }

            // Sorting makes the id independent of the order the points were found in.
            const sortedPointIds = selectedPoints.map(p => p.pointId).sort();
            const groupId = sortedPointIds.join('-');

            addMultiPopup(selectedPoints, groupId);
            fetchTimeseriesForMultiplePoints(selectedPoints, groupId);

            setPolygonPoints([]);
            setCurrentMousePos(null);
        } finally {
            isFinalizing.current = false;
        }
    }, [
        selectionMode,
        isDrawing,
        isFinalizing,
        highlightCanvasRef,
        imageRef,
        polygonPoints,
        getPointsWithinPolygon,
        addMultiPopup,
        fetchTimeseriesForMultiplePoints,
        setPolygonPoints,
        setCurrentMousePos,
        setError
    ]);

    // Mouse-move handler for Polygon mode. While a polygon is being drawn it
    // records the cursor position in natural image coordinates, or null when
    // the cursor is outside the image.
    const handleMouseMove = useCallback((e) => {
        const isTracking = selectionMode === 'Polygon'
            && isDrawing.current
            && highlightCanvasRef.current
            && imageRef.current;
        if (!isTracking) return;

        const cursor = clientToNatural(e.clientX, e.clientY);
        const isOutsideImage = cursor.x === -1 && cursor.y === -1;

        setCurrentMousePos(isOutsideImage ? null : { x: cursor.x, y: cursor.y });
    }, [selectionMode, isDrawing, highlightCanvasRef, imageRef, setCurrentMousePos, clientToNatural]);

    const renderCanvases = useCallback(() => {
        return (
            <CanvasLayer
                staticCanvasRef={staticCanvasRef}
                highlightCanvasRef={highlightCanvasRef}
                selectionMode={selectionMode}
                onClick={selectionMode === 'Point' ? handleCanvasClick : (selectionMode === 'Polygon' ? handlePolygonAddPoint : null)}
                onDoubleClick={selectionMode === 'Polygon' ? handlePolygonEnd : null}
                onMouseMove={selectionMode === 'Polygon' ? handleMouseMove : null}
            />
        );
    }, [
        handleCanvasClick,
        handlePolygonAddPoint,
        handlePolygonEnd,
        handleMouseMove,
        selectionMode,
        staticCanvasRef,
        highlightCanvasRef
    ]);

    return { renderCanvases };
}

export default useCanvasDrawing;
