import clsx from 'clsx'
import { scaleOrdinal } from 'd3'
import {
  SankeyNodeMinimal,
  sankey,
  sankeyJustify,
  sankeyLinkHorizontal,
} from 'd3-sankey'
import cloneDeep from 'lodash.clonedeep'
import uniq from 'lodash.uniq'
import { nanoid } from 'nanoid'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'

import { getFallbackImage } from '@dao-dao/utils'

import { useUpdatingRef } from '../hooks'
import { ImageTextDisplay } from './ImageTextDisplay'

/**
 * A node in the Sankey graph.
 */
export type SankeyDataNode = {
  /**
   * Unique identifier among all nodes. Links reference nodes by this ID, and it
   * also determines the node's color.
   */
  id: string
  /**
   * Optional label shown beside the graph (when a header is set). Nodes without
   * a label get no entry in the label column.
   */
  label?: string
  /**
   * Optional image URL shown with the label. If omitted, a fallback image
   * derived from the node ID is used.
   */
  imageUrl?: string
  /**
   * Optional caption shown with the label.
   */
  caption?: string
}

/**
 * A weighted flow from one node to another.
 */
export type SankeyDataLink = {
  /**
   * Identifier of this link; must not repeat across the graph's links.
   */
  id: string
  /**
   * ID of the node this link flows out of.
   */
  source: string
  /**
   * ID of the node this link flows into.
   */
  target: string
  /**
   * Magnitude of the flow, which the layout uses to size the link.
   */
  value: number
  /**
   * Visual emphasis: `'dim'` fades the link, `'bright'` accentuates it, and
   * `'default'` (or leaving it unset) renders it normally.
   */
  style?: 'default' | 'dim' | 'bright'
}

/**
 * The full graph handed to `Sankey`: nodes plus the links between them.
 */
export type SankeyData = {
  /** Every node that may appear in the graph. */
  nodes: SankeyDataNode[]
  /** Flows between nodes, which refer to nodes by their `id`. */
  links: SankeyDataLink[]
}

export type SankeyProps = {
  /**
   * Nodes and links to render. Nothing is drawn if either list is empty.
   */
  data: SankeyData
  /**
   * The header labels. If undefined, will not show the node labels on either
   * side, just showing the graph.
   */
  header?: {
    /**
     * Source category label.
     */
    source: string
    /**
     * Target category label.
     */
    target: string
  }
  /**
   * Minimum height in pixels. If the height at the ideal ratio given the
   * responsive width is less than this value, the height will be set to this
   * value. Defaults to 200.
   */
  minHeight?: number
  /**
   * Maximum height in pixels. If the height at the ideal ratio given the
   * responsive width is greater than this value, the height will be set to
   * this value. Defaults to 300.
   */
  maxHeight?: number
  /**
   * The ideal ratio of width to height. Defaults to 1.8.
   */
  idealRatio?: number
  /**
   * The width of the nodes in pixels. Defaults to 4.
   */
  nodeWidth?: number
  /**
   * Disable hover effects (hover highlighting and `onHover`). Defaults to
   * false.
   */
  disableHover?: boolean
  /**
   * Optional hover handler. Called with the hovered link, and with undefined
   * when the pointer leaves a link.
   */
  onHover?: (link?: SankeyDataLink) => void
  /**
   * Optional click handler, called with the clicked link. When set, links show
   * a pointer cursor.
   */
  onClick?: (link: SankeyDataLink) => void
  /**
   * Optionally set link ID to highlight. Takes precedence over the link
   * currently being hovered.
   */
  highlightLinkId?: string
  /**
   * Optional class name applied to the outer container.
   */
  className?: string
}

/**
 * Height in pixels reserved at the top of the graph and of each label column
 * for the header captions, so the labels line up with the graph's nodes.
 */
const HEADER_HEIGHT = 40

/**
 * Responsive Sankey diagram.
 *
 * The SVG fills the width of its container, which is measured on mount, on
 * window resize, and by a 500ms poll. The height follows `idealRatio` and is
 * clamped to `[minHeight, maxHeight]`.
 *
 * When `header` is set, the labels of depth-0 nodes are rendered to the left of
 * the graph and those of depth-1 nodes to the right. Each label is vertically
 * aligned with the node it describes.
 */
export const Sankey = ({
  data,
  header,
  minHeight = 200,
  maxHeight = 300,
  idealRatio = 1.8,
  nodeWidth = 4,
  disableHover = false,
  onHover,
  onClick,
  highlightLinkId: _highlightLinkId,
  className,
}: SankeyProps) => {
  const containerRef = useRef<HTMLDivElement | null>(null)
  const [width, setWidth] = useState(0)
  const height = Math.max(
    minHeight,
    Math.min(Math.round(width / idealRatio), maxHeight)
  )

  // Only touches the ref and the (stable) state setter, so the instance
  // captured by the callback/effect below on first render stays valid.
  const handleResize = () =>
    containerRef.current && setWidth(containerRef.current.clientWidth)

  const setContainerRef = useCallback((ref: HTMLDivElement | null) => {
    containerRef.current = ref
    handleResize()
  }, [])

  useEffect(() => {
    // Poll in addition to listening for window resizes, presumably because the
    // container can change size without the window resizing.
    const timer = setInterval(handleResize, 500)
    window.addEventListener('resize', handleResize)

    return () => {
      clearInterval(timer)
      window.removeEventListener('resize', handleResize)
    }
  }, [])

  const [hoveringId, setHoveringId] = useState<string>()
  // An explicitly provided highlight takes precedence over hover state.
  const highlightLinkId = _highlightLinkId || hoveringId

  const onClickRef = useUpdatingRef(onClick)
  const onHoverRef = useUpdatingRef(onHover)
  // Read from the prop (not the ref) so the layout memo re-runs when a click
  // handler is added or removed and the cursor class stays in sync.
  const clickable = !!onClick

  // Per-instance prefix for SVG gradient IDs (which must be unique in the
  // document). Generated once so the IDs don't change on every re-layout.
  const [sankeyId] = useState(() => nanoid())

  const { nodes, allNodes, allLinks } = useMemo(() => {
    if (!data.nodes.length || !data.links.length) {
      return {
        nodes: [],
        allNodes: [],
        allLinks: [],
      }
    }

    // Map each node ID to a palette color (the palette repeats if exhausted).
    const colorScale = scaleOrdinal<string>()
      .domain(data.nodes.map((n) => n.id))
      .range(COLORS)

    // Spread 20% of the height across the gaps between source nodes.
    const sourceNodes = uniq(data.links.map((l) => l.source))
    const nodePadding =
      sourceNodes.length <= 1
        ? 0
        : Math.round((height * 0.2) / (sourceNodes.length - 1))
    const sankeyGen = sankey<SankeyDataNode, SankeyDataLink>()
      .nodeWidth(nodeWidth)
      .nodePadding(nodePadding)
      .extent([
        [0, 0],
        [width, height],
      ])
      .nodeId((node) => node.id)
      .nodeAlign(sankeyJustify)

    // d3-sankey mutates its input (e.g. replacing link source/target IDs with
    // node objects), so lay out a copy rather than the caller's data.
    const { nodes, links } = sankeyGen(cloneDeep(data))

    const allNodes = nodes.map((node) => (
      <g key={node.id}>
        <rect
          fill={colorScale(node.id)}
          fillOpacity={0.8}
          height={node.y1! - node.y0!}
          width={sankeyGen.nodeWidth()}
          x={node.x0}
          y={node.y0}
        />
      </g>
    ))

    const allLinks = links.map((link, i) => {
      const linkGenerator = sankeyLinkHorizontal()
      const path = linkGenerator(link)

      // The laid-out link has node objects in `source`/`target`, contrary to
      // the `SankeyDataLink` type callbacks promise. d3-sankey keeps the links
      // array in input order, so hand callers their original link instead.
      const originalLink = data.links[i] ?? link

      const gradientId = `${sankeyId}-gradient-${i}`

      const style =
        highlightLinkId && highlightLinkId === link.id ? 'bright' : link.style

      const strokeOpacity =
        style === 'dim' ? 0.1 : style === 'bright' ? 0.5 : 0.3

      return (
        path && (
          <g key={link.id} style={{ isolation: 'isolate' }}>
            <path
              className={clsx(
                'transition-all',
                disableHover && 'cursor-default',
                clickable && 'cursor-pointer opacity-100 active:opacity-80'
              )}
              d={path}
              fill="none"
              onClick={() => onClickRef.current?.(originalLink)}
              onMouseOut={
                disableHover
                  ? undefined
                  : () => {
                      setHoveringId(undefined)
                      onHoverRef.current?.(undefined)
                    }
              }
              onMouseOver={
                disableHover
                  ? undefined
                  : () => {
                      setHoveringId(link.id)
                      onHoverRef.current?.(originalLink)
                    }
              }
              stroke={`url(#${gradientId})`}
              strokeOpacity={strokeOpacity}
              strokeWidth={link.width}
              style={{
                mixBlendMode: 'multiply',
              }}
            />
            {/* Horizontal gradient from the source node's color to the target's. */}
            <linearGradient
              gradientUnits="userSpaceOnUse"
              id={gradientId}
              x1={
                (
                  link.source as SankeyNodeMinimal<
                    SankeyDataNode,
                    SankeyDataLink
                  >
                ).x1
              }
              x2={
                (
                  link.target as SankeyNodeMinimal<
                    SankeyDataNode,
                    SankeyDataLink
                  >
                ).x0
              }
            >
              <stop
                offset="0%"
                stopColor={colorScale((link.source as SankeyDataNode).id)}
              />
              <stop
                offset="100%"
                stopColor={colorScale((link.target as SankeyDataNode).id)}
              />
            </linearGradient>
          </g>
        )
      )
    })

    return {
      nodes,
      allNodes,
      allLinks,
    }
  }, [
    clickable,
    data,
    disableHover,
    height,
    highlightLinkId,
    nodeWidth,
    onClickRef,
    onHoverRef,
    sankeyId,
    width,
  ])

  /**
   * Renders a header caption plus the labels of all nodes at `depth`.
   *
   * `nodes` is in input order, not vertical order, and d3-sankey spaces nodes
   * by more than `nodePadding` when a column has spare room. So labels are
   * sorted by position, and each one's top margin is the real gap above its
   * node, keeping it aligned with the node it describes.
   */
  const renderLabelColumn = (title: string, depth: number) => {
    const columnNodes = nodes
      .filter((node) => node.depth === depth)
      .sort((a, b) => a.y0! - b.y0!)

    return (
      <div className="flex flex-col shrink-0">
        <div className="flex flex-col" style={{ height: HEADER_HEIGHT }}>
          <p className="caption-text pb-1 border-b border-border-primary border-dashed">
            {title}
          </p>
        </div>

        {columnNodes.map((node, index) => (
          <div
            key={node.id}
            className="flex flex-col justify-center items-start"
            style={{
              // Distance from the previous node's bottom edge (or from the top
              // of the graph for the first node).
              marginTop: node.y0! - (columnNodes[index - 1]?.y1 ?? 0),
              height: node.y1! - node.y0!,
            }}
          >
            {node.label && (
              <ImageTextDisplay
                caption={node.caption}
                imageUrl={node.imageUrl || getFallbackImage(node.id)}
                label={node.label}
              />
            )}
          </div>
        ))}
      </div>
    )
  }

  return (
    <div
      className={clsx(
        'flex flex-row items-stretch justify-between gap-4',
        className
      )}
    >
      {header && renderLabelColumn(header.source, 0)}

      <div
        className="min-w-0 grow"
        ref={setContainerRef}
        style={header && { paddingTop: HEADER_HEIGHT }}
      >
        <svg height={height} width={width}>
          {allNodes}
          {allLinks}
        </svg>
      </div>

      {header && renderLabelColumn(header.target, 1)}
    </div>
  )
}

/**
 * Palette that node colors are drawn from, keyed by node ID. Link gradients
 * blend between the colors of the two nodes they connect.
 */
const COLORS = [
  '#9e59c6', // purple
  '#627eeb', // blue
  '#c95951', // red
  '#f19a76', // salmon
  '#39a699', // teal
  '#a9e569', // light green
]
