import { useEffect } from 'react';
import { stratify, tree } from 'd3-hierarchy';
import { Edge, Node, ReactFlowState, useReactFlow, useStore } from 'reactflow';

import { NODE_HEIGHT, NODE_WIDTH } from '../constants';

/** Extra gap added on top of NODE_WIDTH / NODE_HEIGHT between neighbouring nodes. */
const SPACE_BETWEEN_NODES = 80;

/**
 * Shared d3 tidy-tree layout (examples: https://observablehq.com/@d3/tree).
 *
 * `nodeSize` assigns every node a fixed [width, height] cell: the node's own size
 * plus SPACE_BETWEEN_NODES. A constant `separation` of 1 gives siblings and
 * non-siblings the same spacing (d3's default pushes non-siblings further apart).
 */
const layout = tree<Node>()
  .nodeSize([
    NODE_WIDTH + SPACE_BETWEEN_NODES, // horizontal cell size
    NODE_HEIGHT + SPACE_BETWEEN_NODES, // vertical cell size
  ])
  .separation(() => 1);

/**
 * Lays out React Flow nodes as a tree using d3-hierarchy.
 *
 * A node's parent is the `source` of the first edge whose `target` is that node.
 * A node with no incoming edge is treated as a root. d3's `stratify` throws when
 * there is more than one root, or when an edge's source is not among `nodes`.
 * In practice, every node except exactly one root needs an incoming edge.
 *
 * @param nodes - the current React Flow nodes; only `id` is read, all other fields are copied through
 * @param edges - the current React Flow edges; only `source` and `target` are read
 * @returns copies of the nodes with `position` replaced by the computed tree coordinates,
 *   in d3 `descendants()` order (root first), not in input order. Empty input yields `[]`.
 */
export function layoutNodes(nodes: Node[], edges: Edge[]): Node[] {
  // stratify() throws ("no root") for an empty data array, so short-circuit it
  if (nodes.length === 0) {
    return [];
  }

  // Map each target id to the source of its FIRST incoming edge, built once.
  // This replaces a per-node `edges.find` (O(n * m)) with O(n + m) work and
  // keeps the same first-match semantics.
  const parentByTarget = new Map<string, string>();
  for (const edge of edges) {
    if (!parentByTarget.has(edge.target)) {
      parentByTarget.set(edge.target, edge.source);
    }
  }

  // convert the flat node list into a d3 hierarchy (undefined parent => root)
  const hierarchy = stratify<Node>()
    .id(node => node.id)
    .parentId(node => parentByTarget.get(node.id))(nodes);

  // run the tree layout; each hierarchy node gains x / y coordinates
  const root = layout(hierarchy);

  // copy the original node (stored as d.data) and only take the position from d3
  return root
    .descendants()
    .map(d => ({ ...d.data, position: { x: d.x, y: d.y } }));
}

/**
 * React Flow store selector used as the re-layout trigger.
 * It yields the number of entries in `nodeInternals`, so a component reading it
 * through `useStore` updates when the node count changes.
 */
const nodeCountSelector = ({ nodeInternals }: ReactFlowState) =>
  nodeInternals.size;

/**
 * Re-lays out the graph as a tree whenever the number of nodes in the React Flow
 * store changes, and writes the new positions back via `setNodes`.
 *
 * Uses `useStore` and `useReactFlow`, so it is presumably called from a component
 * rendered inside React Flow's provider.
 */
export default function useLayout() {
  // the node count is the trigger for re-layouting: whenever it changes,
  // the effect below recomputes the layout
  const nodeCount = useStore(nodeCountSelector);

  const { getEdges, getNodes, setNodes } = useReactFlow();

  // `nodeCount` is not read inside the effect; it is listed in the deps only so
  // the effect re-runs when nodes are added or removed
  useEffect(() => {
    const nodes = getNodes();

    // an empty graph cannot be stratified (d3 throws "no root"), so skip it
    if (nodes.length === 0) {
      return;
    }

    // run the layout and replace the nodes with their repositioned copies.
    // The effect deliberately returns nothing: its return slot is for a cleanup function.
    setNodes(layoutNodes(nodes, getEdges()));
  }, [nodeCount, getEdges, getNodes, setNodes]);
}
