import React, { useRef, useState } from "react";
import { Badge, Stack, Table } from '@mantine/core'
import * as d3 from 'd3';
import SvgComponent from "./SvgComponent";

/**
 * One plotted result on the volcano plot, as consumed by `VolcanoPlot` and `DiffExprDetails`.
 */
interface DataPoint {
    /** Identifier used to track the selected point (assumed unique within a dataset). */
    id: number;
    /** Horizontal coordinate, drawn against the "log2 fold change" axis. */
    x: number;
    /** Vertical coordinate, drawn against the "-log10 pvalue" axis. */
    y: number;
    /** Gene name; matched by filters that set `gene` and shown in the detail table. */
    gene: string;
    /** Comparison label shown in the detail table. */
    comparison: string;
    /** URL rendered as the link in the detail table. */
    geo_link: string;
    /** Flag matched by filters that set `inline` (domain meaning not visible here — TODO confirm). */
    inline: boolean;
    /** When true the point is drawn grey and ignores clicks/hover. */
    greyed_out?: boolean;
}

/** Props accepted by the `VolcanoPlot` component. */
interface VolcanoPlotProps {
    /** Points to draw. */
    data: DataPoint[]
    /** Width of the drawing area in SVG pixels. */
    width: number
    /** Height of the drawing area in SVG pixels. */
    height: number
    /**
     * Highlight filters. When provided, points matching none of them are greyed out; when
     * omitted (or undefined) no filtering is applied. Optional because the component already
     * treats a missing value as "no filters". Elements presumably carry optional `inline` /
     * `gene` keys — TODO confirm with callers before narrowing this type.
     */
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- filter shape not typed yet
    filters?: any[] | undefined
    /** Invoked whenever a non-greyed point is clicked (both when selecting and deselecting it). */
    onPointSelected: (d: DataPoint) => void
}

/**
 * Highlight criterion understood by `VolcanoPlot`. Every field that is set must match the
 * point; a filter with no fields set matches nothing.
 */
interface VolcanoPointFilter {
    inline?: boolean
    gene?: string
}

type LineSelection = d3.Selection<SVGLineElement, unknown, null, undefined>;
type LinearScale = d3.ScaleLinear<number, number>;

/** True when `d` satisfies every criterion present on `filter`; false for an empty filter. */
const meetsFilter = (filter: VolcanoPointFilter, d: DataPoint): boolean => {
    const checksInline = filter.inline !== undefined;
    const checksGene = Boolean(filter.gene);
    if (!checksInline && !checksGene) return false;
    return (!checksInline || d.inline === filter.inline)
        && (!checksGene || d.gene === filter.gene);
};

/**
 * Positions a crosshair pair for point `d` under scales `sx`/`sy`: a vertical line from the
 * bottom of the visible y-domain up to the point, and a horizontal line from the left of the
 * visible x-domain across to it.
 */
const placeCrosshair = (
    vLine: LineSelection,
    hLine: LineSelection,
    d: DataPoint,
    sx: LinearScale,
    sy: LinearScale,
): void => {
    vLine
        .attr("x1", sx(d.x))
        .attr("x2", sx(d.x))
        .attr("y1", sy(sy.domain()[0]))
        .attr("y2", sy(d.y));
    hLine
        .attr("x1", sx(sx.domain()[0]))
        .attr("x2", sx(d.x))
        .attr("y1", sy(d.y))
        .attr("y2", sy(d.y));
};

/**
 * Zoomable volcano plot. Points are coloured green (x >= 0) or red (x < 0), or grey when
 * `filters` is given and none of them match. Clicking a non-greyed point toggles its
 * selection (drawn as a crosshair) and reports it through `onPointSelected`.
 */
const VolcanoPlot: React.FC<VolcanoPlotProps> = ({ data, width, height, filters, onPointSelected }) => {
    const [selectedPoint, setSelectedPoint] = useState<number | null>(null)
    // Last zoom/pan transform; stored by `zoomed` and re-applied at the end of every render call.
    const zoomTransformRef = useRef(d3.zoomIdentity);

    const marginTop = 10;
    const marginRight = 0;
    const marginBottom = 20;
    const marginLeft = 20;

    // Compute greyed-out flags on copies. Mutating the caller's objects leaked into the parent's
    // data and left points greyed out after `filters` was cleared.
    const points: DataPoint[] = filters
        ? data.map(d => ({ ...d, greyed_out: !filters.some(f => meetsFilter(f, d)) }))
        : data;

    const render = (svg: d3.Selection<SVGSVGElement, unknown, null, undefined>) => {
        // Dot layer. Its transform is set entirely by `zoomed`, which runs synchronously at the
        // end of this function via `zoom.transform`.
        const gDot = svg.append("g")
            .attr("fill", "none")
            .attr("stroke-linecap", "round");

        const gGrid = svg.append("g");

        // Symmetric x-domain around 0. d3.max avoids spreading the whole dataset into Math.max
        // (which can overflow the call stack). The `|| 1` fallback keeps the domain finite and
        // non-degenerate for empty or all-zero data.
        const xLimit = d3.max(points, d => Math.abs(d.x)) || 1;
        const yLimit = Math.abs(d3.max(points, d => d.y) ?? 0) || 1;

        const x = d3.scaleLinear()
            .domain([-xLimit, xLimit])
            .range([marginLeft, width - marginRight]);
        const y = d3.scaleLinear()
            .domain([0, yLimit])
            .range([height - marginBottom, marginTop]);
        // Current (zoomed) scales: reassigned by `zoomed`, read by the hover handler.
        let zx: LinearScale = x;
        let zy: LinearScale = y;

        const makeCrosshairLine = (parent: d3.Selection<SVGGElement | SVGSVGElement, unknown, null, undefined>, strokeWidth: number, opacity?: number): LineSelection => {
            const line = parent.append("line")
                .attr("stroke", "#73C2FB")
                .attr("stroke-width", strokeWidth);
            return opacity === undefined ? line : line.attr("opacity", opacity);
        };

        const hoverLineGroup = svg.append("g")
            .style("display", "none");
        const vHoverLine = makeCrosshairLine(hoverLineGroup, 2, 0.3);
        const hHoverLine = makeCrosshairLine(hoverLineGroup, 2, 0.3);

        gDot.selectAll("path")
            .data(points)
            .join("path")
            .attr("d", d => `M${x(d.x)},${y(d.y)}h0`)
            .attr("stroke", d => d.greyed_out ? '#ccc' : d.x >= 0 ? 'green' : 'red')
            .attr("opacity", d => selectedPoint === d.id ? 1 : 0.5)
            .on('click', (_event, d) => {
                if (d.greyed_out) return;
                // Functional update so the toggle never reads a stale selection.
                setSelectedPoint(prev => (prev === d.id ? null : d.id));
                onPointSelected(d);
            })
            .on("mouseover", (_event, d) => {
                if (d.greyed_out) return;
                hoverLineGroup.style("display", null);
                placeCrosshair(vHoverLine, hHoverLine, d, zx, zy);
            })
            .on("mouseout", (_event, d) => {
                if (d.greyed_out) return;
                hoverLineGroup.style("display", "none");
            });

        // Compare against null (not truthiness) so a point with id 0 can be selected, and
        // tolerate the selected id having disappeared from `data`.
        const selected = selectedPoint === null
            ? undefined
            : points.find(p => p.id === selectedPoint);
        const selectLines: [LineSelection, LineSelection] | null = selected
            ? [makeCrosshairLine(svg, 1), makeCrosshairLine(svg, 1)]
            : null;
        if (selected && selectLines) {
            placeCrosshair(selectLines[0], selectLines[1], selected, x, y);
        }

        // Dashed reference lines at x = 0 and y = 0.
        const vpivot = svg.append("line")
            .attr("x1", x(0))
            .attr("x2", x(0))
            .attr("y1", 0)
            .attr("y2", height)
            .attr("stroke", "black")
            .attr("stroke-width", 1)
            .attr("stroke-dasharray", "4 3");

        const hpivot = svg.append("line")
            .attr("x1", 0)
            .attr("x2", width)
            .attr("y1", y(0))
            .attr("y2", y(0))
            .attr("stroke", "black")
            .attr("stroke-width", 1)
            .attr("stroke-dasharray", "4 3");

        const gx = svg.append("g")
            .attr("transform", `translate(0,${height - marginBottom})`)
            .call(g => g.append("text")
                .attr("x", width / 2)
                .attr("y", marginBottom - 5)
                .attr("fill", "currentColor")
                .attr("text-anchor", "middle")
                .style("font-size", "14px")
                .text('log2 fold change'));

        const gy = svg.append("g")
            .attr("transform", `translate(${marginLeft},${0})`)
            .call(g => g.append("text")
                .attr("x", -height / 2)
                .attr("y", -marginLeft + 10)
                .attr("transform", "rotate(-90)")
                .attr("fill", "currentColor")
                .attr("text-anchor", "middle")
                .style("font-size", "14px")
                .text('-log10 pvalue'));

        // Grid lines live in the same (untranslated) coordinate space as the axes. The scale
        // ranges already include the margins, so the margins must not be added again here.
        const grid = (g: d3.Selection<SVGGElement, unknown, null, undefined>, sx: LinearScale, sy: LinearScale) => g
            .attr("stroke", "currentColor")
            .attr("stroke-opacity", 0.1)
            .call(gg => gg
                .selectAll(".x")
                .data(sx.ticks(width / 50))
                .join(
                    enter => enter.append("line").attr("class", "x").attr("y2", height),
                    update => update,
                    exit => exit.remove()
                )
                .attr("x1", d => 0.5 + sx(d))
                .attr("x2", d => 0.5 + sx(d)))
            .call(gg => gg
                .selectAll(".y")
                .data(sy.ticks(height / 50))
                .join(
                    enter => enter.append("line").attr("class", "y").attr("x2", width),
                    update => update,
                    exit => exit.remove()
                )
                .attr("y1", d => 0.5 + sy(d))
                .attr("y2", d => 0.5 + sy(d)));

        function zoomed({ transform }: d3.D3ZoomEvent<SVGSVGElement, unknown>) {
            zx = transform.rescaleX(x).interpolate(d3.interpolateRound);
            zy = transform.rescaleY(y).interpolate(d3.interpolateRound);
            zoomTransformRef.current = transform;
            // Dots are drawn with the base scales, so the zoom transform is applied to their group
            // and the stroke width is divided by k to keep dot size constant on screen.
            gDot.attr("transform", transform.toString()).attr("stroke-width", 5 / transform.k);
            hpivot.attr("y1", zy(0)).attr("y2", zy(0));
            vpivot.attr("x1", zx(0)).attr("x2", zx(0));
            if (selected && selectLines) {
                placeCrosshair(selectLines[0], selectLines[1], selected, zx, zy);
            }
            gx.call(d3.axisTop(zx).ticks(width / 50));
            gy.call(d3.axisRight(zy).ticks(height / 50));
            gGrid.call(grid, zx, zy);
        }

        const zoom = d3.zoom<SVGSVGElement, unknown>()
            .scaleExtent([0.5, 32])
            .on("zoom", zoomed);

        // Re-apply the stored transform so a redraw keeps the user's current zoom/pan.
        svg.call(zoom).call(zoom.transform, zoomTransformRef.current);
    };

    return (
        <SvgComponent
            data={points}
            render={render}
            width={width}
            height={height} />
    );
};


/**
 * Volcano plot plus a one-row detail table for the most recently clicked point.
 * Clicking the same point again hides the table.
 */
const DiffExprDetails: React.FC<{ data: DataPoint[], filters: any[], width: number, height: number }> = ({ data, filters, width, height }) => {
    // `null` until a point is selected; the union type makes the initial null valid under strictNullChecks.
    const [tabularData, setTabularData] = useState<DataPoint | null>(null)

    const onPointSelected = (d: DataPoint) => {
        // Toggle based on the previous state rather than a possibly stale closure value.
        setTabularData(prev => (prev && prev.id === d.id ? null : d));
    }

    // NOTE(review): the plot is drawn 50px narrower than `width`, presumably to leave room for
    // surrounding layout — TODO confirm the intent of this constant.
    return (
        <Stack>
            <VolcanoPlot
                data={data}
                filters={filters}
                width={width - 50}
                height={height}
                onPointSelected={onPointSelected}
            />
            {tabularData ? (
                <Table highlightOnHover>
                    <Table.Thead>
                        <Table.Tr>
                            <Table.Th>Gene</Table.Th>
                            <Table.Th>Comparison</Table.Th>
                            <Table.Th>GEO Link</Table.Th>
                        </Table.Tr>
                    </Table.Thead>
                    <Table.Tbody>
                        <Table.Tr>
                            <Table.Td>{tabularData.gene}</Table.Td>
                            <Table.Td>{tabularData.comparison}</Table.Td>
                            <Table.Td><a href={tabularData.geo_link} >{tabularData.geo_link}</a></Table.Td>
                        </Table.Tr>
                    </Table.Tbody>
                </Table>
            ) : null}
        </Stack>
    );
};

// Module entry point: the volcano plot with its detail table.
export { DiffExprDetails as default };
