import React, { useState, useEffect } from 'react';
import * as d3 from 'd3';
import { sankey as d3Sankey, sankeyLinkHorizontal } from 'd3-sankey';
import SvgComponent from './SvgComponent';

// Visual constants for link strokes. A selected link rests at
// SELECTED_LINK_OPACITY and is bumped to SELECTED_LINK_HOVER_OPACITY while hovered.
const LINK_STROKE = "#000";
const LINK_OPACITY = 0.1;
const SELECTED_LINK_OPACITY = 0.3;
const SELECTED_LINK_HOVER_OPACITY = 0.4;

/**
 * Sankey diagram whose links can be toggled into a selection by clicking them.
 *
 * Fields read from `data` by this component:
 * - `data.nodes[]`: `totalValue`, `layer`, `name`, `color`, `hasLabel`, and optionally
 *   `positiveValue` / `negativeValue` / `zeroValue` (used to split the node bar into segments).
 * - `data.links[]`: `value`, `source`, `target`.
 *
 * NOTE: `data` is mutated in place, both by the d3-sankey layout call and by the manual
 * vertical layout below (`y0`, `y1`, `width` are overwritten).
 *
 * @param data           Graph passed to the d3-sankey layout (`{ nodes, links }`).
 * @param width          Outer SVG width; defaults to 1200.
 * @param height         Outer SVG height; defaults to 1600.
 * @param onRectSelected Optional callback invoked with the array of selected links after
 *                       the first render and after every selection change.
 */
const SankeyChart = ({ data, width = 1200, height = 1600, onRectSelected }) => {
    const [selectedLinks, setSelectedLinks] = useState([])

    // Notify the parent whenever the selection changes. The callback is deliberately not
    // listed as a dependency so that a parent passing a fresh function on every render
    // does not trigger extra notifications.
    useEffect(() => {
        if (typeof onRectSelected === 'function') onRectSelected(selectedLinks)
    }, [selectedLinks])

    // Links are matched by their `index` property.
    const isSelectedLink = d => selectedLinks.some(l => d.index === l.index)

    const margin = { top: 10, right: 10, bottom: 10, left: 10 };
    const sankeyWidth = width - margin.left - margin.right;
    const sankeyHeight = height - margin.top - margin.bottom;
    const nodeWidth = 12
    const nodePadding = 10

    /**
     * Draws the whole chart into the given SVG selection.
     * Handed to SvgComponent, which decides when it is invoked.
     */
    const render = (svg: d3.Selection<SVGSVGElement, unknown, null, undefined>) => {
        const sankey = d3Sankey()
            .nodeWidth(nodeWidth)
            .nodePadding(nodePadding)
            .extent([[0, 0], [sankeyWidth, sankeyHeight]])
            .nodeSort(null)
            .linkSort(null);

        const { nodes, links } = sankey(data);

        // The vertical positions are recomputed by hand below from `totalValue`,
        // replacing whatever y0/y1/width values are on the nodes and links at this point.
        const totalFlow = d3.sum(data.nodes, d => d.totalValue);
        const totalNodePadding = (data.nodes.length - 1) * nodePadding;

        const yScale = d3.scaleLinear()
            .domain([0, totalFlow])
            .range([0, sankeyHeight - totalNodePadding]);

        // Stack nodes top-down inside their layer, separated by nodePadding.
        const nodesOffsets = {}
        data.nodes.forEach(function (node) {
            if (!(node.layer in nodesOffsets)) nodesOffsets[node.layer] = 0
            node.y0 = nodesOffsets[node.layer];
            node.y1 = node.y0 + yScale(node.totalValue);
            nodesOffsets[node.layer] = node.y1 + nodePadding;
        });

        // TODO: revisit this code
        // Stack link endpoints along each node, keyed by node name. `left` tracks the next free
        // offset for links leaving a node, `right` for links arriving at it. A link that no
        // longer fits below the running offset is pinned to the bottom edge of the node.
        const linksOffsets = {}
        data.links.forEach(function (link) {
            link.width = yScale(link.value)
            if (!(link.source.name in linksOffsets)) linksOffsets[link.source.name] = { right: link.source.y0, left: link.source.y0 }
            if (!(link.target.name in linksOffsets)) linksOffsets[link.target.name] = { right: link.target.y0, left: link.target.y0 }
            const sourceOffsets = linksOffsets[link.source.name]
            const targetOffsets = linksOffsets[link.target.name]

            const offset0 = sourceOffsets.left
            link.y0 = link.source.y1 - offset0 >= link.width ? offset0 + link.width / 2 : link.source.y1 - link.width / 2
            sourceOffsets.left = Math.max(sourceOffsets.left, link.y0 + link.width / 2);

            const offset1 = targetOffsets.right
            link.y1 = link.target.y1 - offset1 >= link.width ? offset1 + link.width / 2 : link.target.y1 - link.width / 2
            targetOffsets.right = Math.max(targetOffsets.right, link.y1 + link.width / 2);
        });

        const graphGroup = svg.append('g')
            .attr('transform', `translate(${margin.left},${margin.top})`);

        // Diagonal hatch pattern overlaid on the "negative" part of a node bar.
        svg.append('defs')
            .append('pattern')
            .attr('id', 'diagonalHatch')
            .attr('patternUnits', 'userSpaceOnUse')
            .attr('width', 10)
            .attr('height', 10)
            .append('path')
            .attr('d', 'M0,0 l10,10')
            .attr('stroke', '#000')
            .attr('stroke-width', 1);

        graphGroup.append('g')
            .selectAll('.link')
            .data(links)
            .enter()
            .append('path')
            .attr('class', 'link')
            .attr('d', sankeyLinkHorizontal())
            .style('fill', 'none')
            .style('stroke', d => isSelectedLink(d) ? d.source.color : LINK_STROKE)
            .style('stroke-width', d => Math.max(1, d.width))
            .style('opacity', d => isSelectedLink(d) ? SELECTED_LINK_OPACITY : LINK_OPACITY)
            .on("click", function (event, d) {
                // Toggle membership using the latest state rather than the array captured
                // when this handler was created, so clicks are never applied to a stale list.
                setSelectedLinks(current =>
                    current.some(l => d.index === l.index)
                        ? current.filter(l => d.index !== l.index)
                        : [...current, d]
                );
            })
            .on("mouseover", function (event, d) {
                if (isSelectedLink(d)) {
                    d3.select(this).style('opacity', SELECTED_LINK_HOVER_OPACITY)
                }
                else {
                    d3.select(this).style("stroke", d.source.color);
                }
            })
            .on("mouseout", function (event, d) {
                if (isSelectedLink(d)) {
                    // Return to the resting opacity of a selected link.
                    d3.select(this).style('opacity', SELECTED_LINK_OPACITY)
                }
                else {
                    d3.select(this).style("stroke", LINK_STROKE);
                }
            });

        const node = graphGroup.append('g')
            .selectAll('.node')
            .data(nodes)
            .enter()
            .append('g')
            .attr('class', 'node')
            .attr('transform', d => `translate(${d.x0},${d.y0})`)

        // Fully transparent rect spanning the whole node; the visible bar is drawn by the
        // segments appended below.
        node.append('rect')
            .attr('height', d => d.y1 - d.y0)
            .attr('width', sankey.nodeWidth())
            .style('fill', d => d.color)
            .style('opacity', "0")

        // Split each node bar into positive / negative / zero segments, stacked top-down
        // and sized proportionally to their share of totalValue.
        node.each(function (d) {
            if (d.positiveValue === undefined || d.negativeValue === undefined) return;

            const totalHeight = d.y1 - d.y0;
            const group = d3.select(this);
            // A node with totalValue 0 gets zero-height segments instead of NaN heights.
            const segmentHeight = value => d.totalValue ? totalHeight * (value / d.totalValue) : 0;

            const positiveHeight = segmentHeight(d.positiveValue);
            const negativeHeight = segmentHeight(d.negativeValue);
            // NaN when zeroValue is absent; the `> 0` check below then skips the segment.
            const zeroHeight = segmentHeight(d.zeroValue);

            group.append('rect')
                .attr('y', 0)
                .attr('height', positiveHeight)
                .attr('width', sankey.nodeWidth())
                .style('fill', d => d.color)
                .style('opacity', '1');

            if (negativeHeight > 0) {
                // Tinted base plus hatch overlay.
                group.append('rect')
                    .attr('y', positiveHeight)
                    .attr('height', negativeHeight)
                    .attr('width', sankey.nodeWidth())
                    .style('fill', d => d.color)
                    .style('fill-opacity', '0.7')
                    .style('stroke', 'none');

                group.append('rect')
                    .attr('y', positiveHeight)
                    .attr('height', negativeHeight)
                    .attr('width', sankey.nodeWidth())
                    .style('fill', 'url(#diagonalHatch)')
                    .style('stroke', 'none');
            }

            if (zeroHeight > 0) {
                // Placed below the negative segment so the two do not overlap.
                group.append('rect')
                    .attr('y', positiveHeight + negativeHeight)
                    .attr('height', zeroHeight)
                    .attr('width', sankey.nodeWidth())
                    .style('fill', d => d.color)
                    .style('fill-opacity', '0.2')
                    .style('stroke', 'none');
            }
        });

        // Labels sit to the left of a node by default; nodes in the left half of the chart
        // get their label moved to the right-hand side instead.
        node.append('text')
            .attr('x', -6)
            .attr('y', d => (d.y1 - d.y0) / 2)
            .attr('dy', '.35em')
            .attr('text-anchor', 'end')
            .text(d => d.hasLabel ? d.name : '')
            .filter(d => d.x0 < sankeyWidth / 2)
            .attr('x', 6 + sankey.nodeWidth())
            .attr('text-anchor', 'start');
    };

    return (
        <SvgComponent
            render={render}
            data={data}
            width={width}
            height={height} />
    );
};

export default SankeyChart;
