import {
  ColumnDef,
  flexRender,
  getCoreRowModel,
  Header,
  OnChangeFn,
  SortingState,
  useReactTable,
} from '@tanstack/react-table';
import { useVirtualizer } from '@tanstack/react-virtual';
import { Flex, Typography } from 'antd';
import sum from 'lodash/sum';
import React, {
  ForwardedRef,
  forwardRef,
  useCallback,
  useEffect,
  useImperativeHandle,
  useLayoutEffect,
  useMemo,
  useRef,
  useState,
} from 'react';

import { IconSortingAsc, IconSortingDefault, IconSortingDesc } from '@assets';
import { useInfiniteScroll } from '@components/InfiniteTable/useInfiniteScroll';
import { useStorage } from '@hooks/useStorage';

import * as S from './styled';
import { type InfiniteTableRef } from './types';

/**
 * Row shape stored by the table: the caller's entity extended with the page it
 * was loaded from. Rows flagged with `isEmpty` are rendered as skeletons and do
 * not trigger `onRowClick`.
 */
type TableItem<Entity> = Entity & {
  page: number;
  isEmpty?: boolean;
};

/** Row height in px used whenever `cellHeight` is not supplied. */
const DEFAULT_CELL_HEIGHT = 40;
/**
 * Props of `InfiniteTable`.
 *
 * `TData` is the row type: the caller's `Entity` plus the `page` the row belongs
 * to (see `TableItem`).
 */
type InfiniteTableProps<Entity, TData extends TableItem<Entity>> = {
  /**
   * Columns definition for React Table.
   * A column whose `size` is `Number.MAX_SAFE_INTEGER` is stretched to take the
   * container width left over by the other columns.
   * @see https://tanstack.com/table/v8/docs/api/core/column-def
   */
  columns: ColumnDef<TData>[];
  /**
   * Rows loaded so far, across all loaded pages. Rows with `isEmpty: true` are
   * rendered as skeletons and ignore clicks.
   */
  data: TData[];
  /**
   * Total amount of items in the table. When it is 0 and no page is loading, a
   * "No data" placeholder is rendered.
   */
  totalItems: number;
  /**
   * Amount of items per page. Also the number of skeleton rows drawn for each
   * pending page.
   */
  perPage: number;
  /**
   * Pages that are currently reported as visible. The table compares its freshly
   * computed list against this value to decide whether to call
   * `onVisiblePagesChange`.
   */
  visiblePages: number[];
  /**
   * Whether the table is loading the first time. Whenever this is (or becomes)
   * falsy, the container is scrolled to `scrollTop`.
   */
  isLoading?: boolean;
  /**
   * Whether the table is fetching the next or prev page.
   * NOTE(review): not currently read by the table component.
   */
  isFetching?: boolean;
  /**
   * Height of a single row, in px. Falls back to `DEFAULT_CELL_HEIGHT` (40).
   */
  cellHeight?: number;
  /**
   * Initial scroll position. Needed for ScrollRestoration component,
   * you don't need to set it manually.
   */
  scrollTop?: number;
  /**
   * Whether the table has more pages.
   * NOTE(review): not currently read by the table component.
   */
  hasNextPage?: boolean;
  /**
   * Whether the table has a previous page.
   * NOTE(review): not currently read by the table component, which instead
   * derives this from the lowest `page` present in `data`.
   */
  hasPrevPage?: boolean;
  /**
   * Pages that are currently in a pending state. Skeleton rows are rendered at
   * the position of each listed page.
   */
  loadingPages?: number[];
  /**
   * Intended as the number of pages to render before and after the visible area.
   * NOTE(review): not currently read by the table component (it is not passed to
   * the virtualizer), so setting it has no effect.
   */
  overscan?: number;
  /**
   * Sorting state of the table (controlled; only single-column sorting is enabled).
   */
  sorting?: SortingState;
  /**
   * Handler for detecting the scroll event. Needed for ScrollRestoration component,
   * you don't need to set it manually.
   */
  onScroll?: (event: React.UIEvent<HTMLElement>) => void;
  /**
   * Called on mount and on scroll with the pages (numbered from 1) that intersect
   * the viewport, when they differ from `visiblePages`.
   * @param pages currently visible pages
   */
  onVisiblePagesChange: (pages: number[]) => void;
  /**
   * Handler for clicking on a row. Not called for rows with `isEmpty: true`.
   * @param event common MouseEvent
   * @param row item that has been clicked
   * @param index position of the item in `data`
   */
  onRowClick?: (event: React.MouseEvent, row: TData, index: number) => void;
  /**
   * Handler for changing the sorting state.
   */
  onSortChange?: OnChangeFn<SortingState>;
  /**
   * Enables column resizing and persists the resulting column widths in storage
   * under this key.
   */
  storageKeyForResizeState?: string;
  /**
   * Visibility flag per column id.
   */
  columnVisibility?: Record<string, boolean>;
  /**
   * ID of the entity to highlight. The matching row is marked as selected and,
   * when this value changes, scrolled into view if it lies outside the viewport.
   */
  highlightedId?: string;
};

/**
 * Virtualized table with page-based infinite scrolling.
 *
 * Rows in `props.data` are virtualized with `useVirtualizer` at a fixed row height
 * (`props.cellHeight`, falling back to `DEFAULT_CELL_HEIGHT`). Skeleton rows are
 * drawn for the page just before the lowest loaded page and for every page in
 * `props.loadingPages`. Pages intersecting the viewport are reported through
 * `props.onVisiblePagesChange` on mount and on every scroll.
 *
 * When `props.storageKeyForResizeState` is set, every resizable column except the
 * last one gets a resize handle, and column widths are persisted via `useStorage`
 * under that key. A column whose `size` is `Number.MAX_SAFE_INTEGER` is stretched
 * to the container width left over by the other columns.
 *
 * The forwarded ref exposes `resetScroll()` and `scrollToElement(index)`.
 */
const InfiniteTableComponent = <
  Entity extends { id: string } & Record<string, unknown>,
  TData extends TableItem<Entity>,
>(
  props: InfiniteTableProps<Entity, TData>,
  ref: ForwardedRef<InfiniteTableRef>,
) => {
  const rootRef = useRef<HTMLDivElement>(null);
  const highlightedRowRef = useRef<HTMLTableRowElement>(null);

  // Single row height used by every scroll/page computation below. `??` keeps an
  // explicitly passed value instead of replacing falsy values with the default.
  const cellHeight = props.cellHeight ?? DEFAULT_CELL_HEIGHT;

  // Imperative API exposed through the forwarded ref.
  useImperativeHandle(ref, () => ({
    resetScroll: () => {
      rootRef.current?.scrollTo({ top: 0 });
    },
    scrollToElement: (index) => {
      rootRef.current?.scrollTo({
        behavior: 'smooth',
        top: index * cellHeight,
      });
    },
  }));

  // calculates the width sum of the columns that should not stretch
  const filledWidth = useMemo<number>(
    () =>
      sum(
        props.columns
          .map((col) =>
            col.enableResizing === false
              ? col.maxSize
              : col.size || col.minSize,
          )
          .filter((size) => size !== Number.MAX_SAFE_INTEGER),
      ),
    [props.columns],
  );

  const [columnSizes, setColumnSizes] = useStorage(
    props.storageKeyForResizeState || '',
    {},
    { enabled: !!props.storageKeyForResizeState },
  );

  // Width of the scroll container, measured after the DOM is committed.
  // `rootRef.current` is null during the first render, and attaching the ref does
  // not trigger a re-render, so it cannot be read during render. The layout effect
  // measures it before paint and re-measures whenever the column definitions change.
  const [rootWidth, setRootWidth] = useState(0);

  useLayoutEffect(() => {
    setRootWidth(rootRef.current?.offsetWidth ?? 0);
  }, [props.columns]);

  const columnsWithSizes = useMemo(
    () =>
      props.columns.map((col) => {
        const size =
          col.size === Number.MAX_SAFE_INTEGER // this will stretch the column to fill available space
            ? rootWidth - filledWidth
            : col.minSize || col.maxSize;
        return { ...col, size };
      }),
    [rootWidth, filledWidth, props.columns],
  );

  const table = useReactTable<TData>({
    data: props.data,
    columns: columnsWithSizes,
    getCoreRowModel: getCoreRowModel(),
    enableMultiSort: false,
    columnResizeMode: props.storageKeyForResizeState ? 'onChange' : 'onEnd',
    initialState: {
      columnSizing: props.storageKeyForResizeState ? columnSizes : {},
    },
    state: {
      sorting: props.sorting,
      columnVisibility: props.columnVisibility,
    },
    onSortingChange: props.onSortChange,
    defaultColumn: {
      size: 0,
      minSize: 100,
    },
    getRowId: (row) => row.id,
  });

  const virtualizer = useVirtualizer({
    count: props.data.length,
    getScrollElement: () => rootRef.current,
    estimateSize: () => cellHeight,
  });

  // Lowest page present in `data` (Infinity when `data` is empty).
  const minPage = props.data
    .map(({ page }) => page)
    .reduce((min, page) => Math.min(min, page), Infinity);

  // Pages are numbered from 1, so anything above page 1 has a page before it.
  const hasPrevPage = Number.isFinite(minPage) ? minPage > 1 : false;

  const infiniteScroll = useInfiniteScroll({
    height: cellHeight,
    perPage: props.perPage,
  });

  /**
   * Computes the pages intersecting the viewport and reports them through
   * `onVisiblePagesChange` when they differ from `props.visiblePages`.
   * @param scrollTop vertical scroll offset of the root element, in px
   * @param clientHeight visible height of the root element, in px
   */
  const updateVisiblePages = (scrollTop: number, clientHeight: number) => {
    const firstPage = Math.floor(scrollTop / cellHeight / props.perPage) + 1;

    const lastPage =
      Math.floor((scrollTop + clientHeight) / cellHeight / props.perPage) + 1;

    // A negative scroll offset can produce pages below 1. They are dropped BEFORE
    // comparing, otherwise the comparison never matches and the callback fires
    // with an unchanged list on every scroll event.
    const visiblePages = Array.from(
      { length: lastPage - firstPage + 1 },
      (_, i) => firstPage + i,
    ).filter((page) => page >= 1);

    if (props.visiblePages?.join(',') !== visiblePages.join(',')) {
      props.onVisiblePagesChange(visiblePages);
    }
  };

  // Report the initially visible pages once, using the restored scroll position
  // when one is provided.
  useEffect(() => {
    if (!rootRef.current) return;

    const { scrollTop = rootRef.current.scrollTop } = props;
    const { clientHeight } = rootRef.current;

    updateVisiblePages(scrollTop, clientHeight);
    // eslint-disable-next-line
  }, []);

  // Restore the scroll position once the first load has finished.
  useLayoutEffect(() => {
    if (rootRef.current && !props.isLoading) {
      rootRef.current.scrollTo(0, props.scrollTop || 0);
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [props.isLoading]);

  // Scroll the highlighted row into view when it lies outside the container.
  useEffect(() => {
    if (!props.highlightedId) return;

    const containerRect = rootRef.current?.getBoundingClientRect();
    const highlightedRect = highlightedRowRef.current?.getBoundingClientRect();

    const isHiddenOnBottom =
      (highlightedRect?.bottom || 0) > (containerRect?.bottom || 0);
    const isHiddenOnTop =
      (highlightedRect?.top || 0) < (containerRect?.top || 0);

    if (isHiddenOnBottom || isHiddenOnTop) {
      highlightedRowRef.current?.scrollIntoView({
        behavior: 'smooth',
        block: isHiddenOnBottom ? 'end' : 'start',
      });
    }
  }, [props.highlightedId]);

  /**
   * Renders the header sort indicator: nothing for unsortable columns, the
   * active asc/desc icon when sorted, otherwise the neutral icon.
   */
  const renderSortingIcon = (
    canBeSorted: boolean,
    sortingState: 'asc' | 'desc' | false,
  ) => {
    if (!canBeSorted) {
      return null;
    }

    if (sortingState) {
      return (
        <S.Icon active>
          {sortingState === 'asc' ? <IconSortingAsc /> : <IconSortingDesc />}
        </S.Icon>
      );
    }

    return (
      <S.Icon>
        <IconSortingDefault />
      </S.Icon>
    );
  };

  // Recompute visible pages on every scroll, then forward the event.
  const handleScroll = (event: React.UIEvent<HTMLDivElement>) => {
    const { scrollTop, clientHeight } = event.currentTarget;

    updateVisiblePages(scrollTop, clientHeight);

    props.onScroll?.(event);
  };

  // One placeholder entry per skeleton row of a page.
  const skeletons = new Array(props.perPage).fill(0);

  const { rows } = table.getRowModel();

  // calculate column sizes into variable to avoid calculating them on every cell
  const columnSizeVars = useMemo(() => {
    const headers = table.getFlatHeaders();
    return headers.reduce(
      (acc: Record<string, number | string>, header) => ({
        ...acc,
        [`--header-${header.id}-size`]: header.getSize(),
      }),
      {},
    );
  }, [props, table.getState().columnSizingInfo, table.getState().columnSizing]); // eslint-disable-line

  const getColumnWidthByCssVar = useCallback(
    (prefix: string | undefined) => `calc(var(--header-${prefix}-size) * 1px)`,
    [],
  );

  // A resize starts with mousedown on the handle, but the matching mouseup can
  // happen anywhere on the page. Each drag therefore registers a single one-shot
  // document listener that restores text selection and persists the final width.
  // `once: true` removes it after it fires, so listeners never accumulate.
  const getMouseDownHandler = useCallback(
    (header: Header<TData, unknown>) => (event: React.MouseEvent) => {
      // prevent sorting on resize
      event.stopPropagation();

      // fix safari text selection while resizing
      document.body.style.webkitUserSelect = 'none';
      const handler = header.getResizeHandler();

      document.addEventListener(
        'mouseup',
        () => {
          document.body.style.webkitUserSelect = 'auto';
          setColumnSizes((sizes) => ({
            ...sizes,
            [header.id]: header.getSize(),
          }));
        },
        { once: true },
      );

      return handler(event);
    },
    [setColumnSizes],
  );

  return (
    <S.Root onScroll={handleScroll} ref={rootRef}>
      <S.Table style={columnSizeVars}>
        <S.Thead>
          {table.getHeaderGroups().map((headerGroup) => (
            <S.TRow
              style={{ display: 'flex', width: '100%' }}
              key={headerGroup.id}
            >
              {headerGroup.headers.map((header, idx) => (
                <S.Th
                  key={header.id}
                  style={{
                    display: 'flex',
                    minWidth: getColumnWidthByCssVar(header.id),
                  }}
                  onMouseDown={header.column.getToggleSortingHandler()}
                >
                  <Flex flex={1} justify="space-between" align="center">
                    <Typography.Text strong>
                      {header.isPlaceholder
                        ? null
                        : flexRender(
                            header.column.columnDef.header,
                            header.getContext(),
                          )}
                    </Typography.Text>

                    {renderSortingIcon(
                      header.column.getCanSort(),
                      header.column.getIsSorted(),
                    )}

                    {props.storageKeyForResizeState &&
                      header.column.getCanResize() &&
                      idx !== headerGroup.headers.length - 1 && (
                        <S.ResizeHandle
                          onClick={(event) => event.stopPropagation()}
                          onDoubleClick={() => header.column.resetSize()}
                          onMouseDown={getMouseDownHandler(header)}
                          className={`resize-handle ${
                            header.column.getIsResizing() ? 'isResizing' : ''
                          }`}
                        />
                      )}
                  </Flex>
                </S.Th>
              ))}
            </S.TRow>
          ))}
        </S.Thead>

        <S.TBody style={{ height: virtualizer.getTotalSize() }}>
          {/* Skeleton rows for the page just before the lowest loaded page. */}
          {hasPrevPage &&
            skeletons.map((_, index) => (
              <S.TRow
                key={index}
                style={{
                  ...infiniteScroll.getStyle(
                    minPage - 1,
                    index + props.perPage - skeletons.length,
                  ),
                  width: '100%',
                }}
              >
                {table.getHeaderGroups().map((headerGroup) =>
                  headerGroup.headers.map((header) => (
                    <S.Td
                      key={header.id}
                      style={{
                        display: 'flex',
                        width: getColumnWidthByCssVar(header.id),
                        minWidth: getColumnWidthByCssVar(header.id),
                      }}
                    >
                      <S.Skeleton />
                    </S.Td>
                  )),
                )}
              </S.TRow>
            ))}

          {virtualizer.getVirtualItems().map((virtualRow) => {
            const row = rows[virtualRow.index];
            const transform = `translateY(${virtualRow.start}px)`;

            return (
              <S.TRow
                key={row.id}
                ref={props.highlightedId === row.id ? highlightedRowRef : null}
                isSelected={props.highlightedId === row.id}
                style={{
                  height: `${virtualRow.size}px`,
                  position: 'absolute',
                  transform: transform,
                  width: '100%',
                }}
                onClick={(event) =>
                  !row.original.isEmpty &&
                  props.onRowClick?.(event, row.original, row.index)
                }
              >
                {row.getVisibleCells().map((cell) => (
                  <S.Td
                    key={cell.id}
                    style={{
                      display: 'flex',
                      width: getColumnWidthByCssVar(cell?.column.id),
                      minWidth: getColumnWidthByCssVar(cell?.column.id),
                    }}
                  >
                    {row.original.isEmpty ? (
                      <S.Skeleton />
                    ) : (
                      <S.TextEllipsis>
                        {flexRender(
                          cell.column.columnDef.cell,
                          cell.getContext(),
                        )}
                      </S.TextEllipsis>
                    )}
                  </S.Td>
                ))}
              </S.TRow>
            );
          })}

          {/* Skeleton rows for pages still being fetched; the key includes the
              page so rows of different pending pages never share a key. */}
          {props.loadingPages?.map((page) =>
            skeletons.map((_, index) => (
              <S.TRow
                key={`${page}-${index}`}
                style={{
                  ...infiniteScroll.getStyle(page, index),
                  width: '100%',
                }}
              >
                {table.getHeaderGroups().map((headerGroup) =>
                  headerGroup.headers.map((header) => (
                    <S.Td
                      key={header.id}
                      style={{
                        display: 'flex',
                        width: getColumnWidthByCssVar(header.id),
                        minWidth: getColumnWidthByCssVar(header.id),
                      }}
                    >
                      <S.Skeleton />
                    </S.Td>
                  )),
                )}
              </S.TRow>
            )),
          )}
        </S.TBody>
      </S.Table>

      {props.totalItems === 0 && !props.loadingPages?.length && (
        <Flex justify="center" align="center" style={{ height: '100%' }}>
          <Typography.Text>No data</Typography.Text>
        </Flex>
      )}
    </S.Root>
  );
};

/**
 * Public call signature of the table. The return type of `forwardRef` does not
 * keep the generic parameters of `InfiniteTableComponent`, so they are declared
 * again here, and the forwarded component is cast to this signature. Callers can
 * then still infer `Entity`/`TData` from the props they pass.
 */
type GenericInfiniteTable = <
  Entity extends { id: string } & Record<string, unknown>,
  TData extends TableItem<Entity>,
>(
  props: InfiniteTableProps<Entity, TData> & {
    ref?: ForwardedRef<InfiniteTableRef>;
  },
  ref: ForwardedRef<InfiniteTableRef>,
) => ReturnType<typeof InfiniteTableComponent>;

const InfiniteTable = forwardRef(
  InfiniteTableComponent,
) as GenericInfiniteTable;

export default InfiniteTable;
