import Plotly, { PlotMouseEvent } from "plotly.js-basic-dist-min";
import createPlotlyComponent from "react-plotly.js/factory";
// react-plotly.js's factory binds the <Plot> component to a specific Plotly
// build; here that is the basic, minified dist imported above rather than the
// full plotly.js bundle.
const Plot = createPlotlyComponent(Plotly);

import { AlignmentRecord } from "@/api/alignment/schemas";
import {
  UMAPDatum,
  UNANNOTATED_CLIP_LABEL,
} from "@/api/protein_search/schemas";
import { QUERY_COLOR } from "@/utils/colors";
import { cn as classNames } from "@/utils/strings";
import { useState } from "react";
import { Link } from "react-router-dom";

/** The two ways the UMAP points can be colored. */
type ColorScheme = "cosine_sim" | "sequence_sim";

/**
 * Returns the smallest and largest of `values`, or `{ min: 0, max: 0 }` for an
 * empty list.
 *
 * `Math.min(...values)` / `Math.max(...values)` yield `Infinity` / `-Infinity`
 * for an empty array (which the legend would print verbatim) and pass every
 * element as a call argument, so a plain loop is used instead.
 */
const valueRange = (values: number[]): { min: number; max: number } => {
  if (values.length === 0) {
    return { min: 0, max: 0 };
  }
  let min = Infinity;
  let max = -Infinity;
  for (const value of values) {
    min = Math.min(min, value);
    max = Math.max(max, value);
  }
  return { min, max };
};

/**
 * Maps a percent identity onto the legend's blue -> red gradient.
 *
 * The value is normalised to `[min, max]` (the same bounds printed under the
 * legend), so a point at the legend's minimum is blue and one at its maximum is
 * red. When the range is degenerate (`max <= min`, e.g. a single alignment) the
 * absolute 0-100 scale is used instead, to avoid dividing by zero.
 */
const sequenceSimilarityColor = (
  percentIdentity: number,
  min: number,
  max: number,
): string => {
  const fraction =
    max > min ? (percentIdentity - min) / (max - min) : percentIdentity / 100;
  const red = Math.round(Math.min(1, Math.max(0, fraction)) * 255);
  return `rgb(${red}, 0, ${255 - red})`;
};

/**
 * UMAP scatter plot of the query protein and its search matches.
 *
 * Points can be colored by cosine similarity (the precomputed
 * `cos_sim_color` of each match) or by sequence similarity (percent identity
 * from `alignmentData`). The query point is always drawn larger and in
 * `QUERY_COLOR`. Hovering a match outlines it.
 *
 * @param queryUMAPDatum - the query's point (its `cdsId` is `"query"`).
 * @param matchUMAPData - one point per search match.
 * @param alignmentData - alignments used for the sequence-similarity coloring;
 *   sequence similarity is unavailable while this is undefined.
 * @param isFetchingAlignment - true while alignments are loading.
 * @param alignmentError - alignment fetch error, if any.
 * @param className - extra classes for the outer container.
 * @param onClick - called with the clicked point's CDS shorthand, or `null`
 *   when the query point is clicked.
 */
export const UMAP = ({
  className,
  onClick,
  queryUMAPDatum,
  matchUMAPData,
  alignmentData,
  isFetchingAlignment,
  alignmentError,
}: {
  queryUMAPDatum: UMAPDatum;
  matchUMAPData: UMAPDatum[];
  alignmentData?: AlignmentRecord[];
  isFetchingAlignment: boolean;
  alignmentError: Error | null;
  className?: string;
  onClick?: (cds_shorthand: string | null) => void;
}) => {
  const [colorScheme, setColorScheme] = useState<ColorScheme>("cosine_sim");
  const [hoveredCdsShorthand, setHoveredCdsShorthand] = useState<string | null>(
    null,
  );

  // Sequence similarity can only be drawn when alignment data is present. If it
  // is missing (not loaded yet, or cleared by a refetch), fall back to cosine
  // similarity so the point colors and the legend always describe the same
  // scheme.
  const activeColorScheme: ColorScheme =
    colorScheme === "sequence_sim" && alignmentData !== undefined
      ? "sequence_sim"
      : "cosine_sim";

  // Bounds printed under the legend; sequence-similarity colors are also
  // normalised to them.
  const { min: minValue, max: maxValue } =
    activeColorScheme === "sequence_sim" && alignmentData
      ? valueRange(alignmentData.map((ar) => ar.percentIdentity))
      : valueRange(matchUMAPData.map((record) => record.cos_sim_score));

  // Index percent identities by CDS shorthand once, instead of a linear `find`
  // per point. The first alignment per CDS wins, as `find` did.
  const percentIdentityByCds = new Map<string, number>();
  for (const ar of alignmentData ?? []) {
    if (!percentIdentityByCds.has(ar.refCdsId.cds_shorthand)) {
      percentIdentityByCds.set(ar.refCdsId.cds_shorthand, ar.percentIdentity);
    }
  }

  const markerColor = (record: UMAPDatum) => {
    if (record.cdsId === "query") {
      return QUERY_COLOR; // matches FoldingCard query structure color
    }
    if (activeColorScheme === "cosine_sim") {
      return record.cos_sim_color;
    }
    const percentIdentity = percentIdentityByCds.get(
      record.cdsId.cds_shorthand,
    );
    // Matches without an alignment are gray ("No alignment" in the legend).
    return percentIdentity === undefined
      ? "gray"
      : sequenceSimilarityColor(percentIdentity, minValue, maxValue);
  };

  const data = [queryUMAPDatum, ...matchUMAPData];
  const trace: Partial<Plotly.ScatterData> = {
    name: "", // if you don't pass a name, plotly will use the string "trace0" in places like the legend/on hover to refer to the dataset
    x: data.map((record) => record.x),
    y: data.map((record) => record.y),
    // Each point carries its CDS shorthand (null for the query) so the click
    // and hover handlers can tell which point was hit.
    customdata: data.map((record) =>
      record.cdsId === "query" ? null : record.cdsId.cds_shorthand,
    ),
    hovertemplate: data.map((record) => {
      const title =
        record.cdsId === "query"
          ? "Query"
          : (record.clipAnnotation?.clipDescription ?? UNANNOTATED_CLIP_LABEL);
      return `${title} | Click to view`;
    }),
    mode: "markers" as const,
    type: "scatter" as const,
    marker: {
      size: data.map((record) => (record.cdsId === "query" ? 20 : 10)),
      line: {
        // Outline only the currently hovered match.
        width: data.map((record) =>
          record.cdsId !== "query" &&
          record.cdsId.cds_shorthand === hoveredCdsShorthand
            ? 3
            : 0,
        ),
        color: "black",
      },
      symbol: "circle",
      color: data.map(markerColor),
    },
  };

  const layout: Partial<Plotly.Layout> = {
    title: "",
    titlefont: {
      size: 16,
      color: "#ffffff",
    },

    paper_bgcolor: "rgb(252,251,250)", // bg-molstar
    plot_bgcolor: "rgb(252,251,250)",
    // remove axes/lines
    yaxis: {
      color: "rgb(252,251,250)",
    },
    xaxis: {
      color: "rgb(252,251,250)",
    },
    dragmode: "pan",
    margin: { b: 5, t: 5, l: 5, r: 5 },
    uirevision: "0", // ensures that zoom/layout does not change on update
    autosize: true,
  };
  const config: Partial<Plotly.Config> = {
    responsive: true,
    // hide all plotly chrome
    displaylogo: false,
    scrollZoom: true,
    modeBarButtonsToRemove: [
      "autoScale2d",
      "hoverClosestCartesian",
      "hoverCompareCartesian",
      "lasso2d",
      "select2d",
      "sendDataToCloud",
      "zoomIn2d",
      "zoomOut2d",
      "toImage",
    ],
  };
  return (
    <div className={classNames(className, "flex flex-col gap-2")}>
      <div className="flex flex-col justify-start ">
        <div className="flex items-start gap-2 font-semibold">
          UMAP
          <p className="text-noir-400">{matchUMAPData.length} results</p>
        </div>
        <p className="text-left text-xs font-light">
          <Link
            className={classNames(
              "border-b border-dashed border-brand-600",
              "transition-colors duration-200 ease-in-out hover:text-brand-600",
            )}
            to="https://umap-learn.readthedocs.io/en/latest/"
          >
            Visualization
          </Link>{" "}
          of protein sequences, represented by vector embeddings.
        </p>
        <p className="text-left text-xs font-light">
          Closer points may indicate similarities in sequence or function.
        </p>
      </div>

      <div
        className={classNames(
          "flex flex-col overflow-hidden rounded-md border px-4",
          "h-full bg-molstar",
        )}
      >
        <span className="flex justify-between gap-2 ">
          <label className="label flex flex-col items-start">
            <h4 className=" pb-1 pl-1 text-xs">Color by</h4>
            <select
              value={colorScheme}
              className="rounded-md border bg-molstar !text-xs"
              onChange={(e) => {
                const next = e.target.value as ColorScheme;
                // Sequence similarity needs usable alignment data. Refuse the
                // switch (and stay on cosine) while it is loading, failed, or
                // absent. Without the early return, the fallback would be
                // immediately overwritten.
                if (
                  next === "sequence_sim" &&
                  (isFetchingAlignment ||
                    alignmentError ||
                    alignmentData === undefined)
                ) {
                  console.log("Can't switch to sequence_sim");
                  setColorScheme("cosine_sim");
                  return;
                }
                setColorScheme(next);
              }}
            >
              <option value="cosine_sim">Cosine Similarity</option>
              <option value="sequence_sim" disabled={!alignmentData}>
                Sequence Similarity {isFetchingAlignment && "(loading)"}
                {alignmentError && "(failed)"}
              </option>
            </select>
          </label>
          <ColorLegend
            colorScheme={activeColorScheme}
            className="w-64 pt-2"
            minValue={minValue}
            maxValue={maxValue}
          />
        </span>
        <Plot
          className="h-full min-h-80"
          data={[trace]}
          layout={layout}
          config={config}
          useResizeHandler={true}
          onClick={(click: PlotMouseEvent) => {
            const point = click.points[0];
            if (!point) {
              return;
            }
            // customdata is the cds shorthand (null for the query point)
            onClick?.(point.customdata as string | null);
          }}
          onHover={(hover: PlotMouseEvent) => {
            const point = hover.points[0];
            setHoveredCdsShorthand(
              point ? (point.customdata as string | null) : null,
            );
          }}
          onUnhover={() => {
            setHoveredCdsShorthand(null);
          }}
        />
      </div>
    </div>
  );
};

/**
 * Legend for the UMAP plot. It shows a gradient bar for the given color
 * scheme with min/max labels under its ends. It also shows swatches for the
 * query point and, for sequence similarity, for points without an alignment.
 *
 * @param colorScheme - selects the gradient and the label format: sequence
 *   similarity is shown as whole percentages, cosine similarity to 2 d.p.
 * @param className - extra classes for the container.
 * @param minValue - value labelled at the left end of the gradient.
 * @param maxValue - value labelled at the right end of the gradient.
 */
const ColorLegend = ({
  colorScheme,
  className,
  minValue,
  maxValue,
}: {
  colorScheme: "cosine_sim" | "sequence_sim";
  className?: string;
  minValue: number;
  maxValue: number;
}) => {
  const sequenceColorScale = [
    "rgb(0, 0, 255)",
    "rgb(125, 0, 125)",
    "rgb(255, 0, 0)",
  ];
  // Non-finite bounds (e.g. Math.min/Math.max over an empty list) would
  // otherwise be printed as "Infinity" / "-Infinity".
  const formatValue = (value: number) => {
    if (!Number.isFinite(value)) {
      return "–";
    }
    return colorScheme === "sequence_sim"
      ? `${value.toFixed(0)}%`
      : `${value.toFixed(2)}`;
  };
  return (
    <div className={classNames(className, "relative flex-col")}>
      <header className="flex items-center gap-2 pb-2 ">
        <h4 className="mr-auto text-xs">Similarity</h4>
        {colorScheme === "sequence_sim" && (
          <span className="flex items-center gap-1 text-xs">
            <span className="h-1 w-1 rounded-full bg-[gray]" />
            No alignment
          </span>
        )}
        <span className="flex items-center gap-1 text-xs">
          {/* Same constant the plot uses for the query marker, so the two cannot drift apart. */}
          <span
            className="h-1 w-1 rounded-full"
            style={{ backgroundColor: QUERY_COLOR }}
          />
          Query
        </span>
      </header>

      <div
        className="h-1 rounded-md"
        style={{
          background:
            colorScheme === "cosine_sim"
              ? "linear-gradient(to right, #440154, #414487, #2a788e, #22a884, #7ad151, #fde725)"
              : `linear-gradient(to right, ${sequenceColorScale.join(", ")})`,
        }}
      />
      <span className="absolute left-0 text-xs">{formatValue(minValue)}</span>
      <span className="absolute right-0 text-xs">{formatValue(maxValue)}</span>
    </div>
  );
};
