import { useMediaQuery, ListSubheader, useTheme } from '@material-ui/core';
import type { AutocompleteRenderGroupParams } from '@material-ui/lab/Autocomplete';
import type { ReactNode } from 'react';
import {
  cloneElement,
  createContext,
  forwardRef,
  useContext,
  useRef,
  useEffect,
  Children,
  isValidElement,
  useMemo,
} from 'react';
import type { ListChildComponentProps } from 'react-window';
import { VariableSizeList } from 'react-window';

/**
 * Padding in px added above the first row and below the last row of the
 * virtualized list. Rows are shifted down by this amount, and the list height
 * adds it twice.
 */
const listboxPadding = 8;

/**
 * Row renderer for react-window's `VariableSizeList`.
 *
 * Clones the pre-rendered element at `index` of the list's `itemData`. It
 * applies the absolute-positioning style react-window computed for that row,
 * with `top` shifted down by `listboxPadding`.
 */
function renderRow({ data, index, style }: ListChildComponentProps) {
  const paddedTop = (style.top as number) + listboxPadding;
  const rowStyle = { ...style, top: paddedTop };
  return cloneElement(data[index], { style: rowStyle });
}

/**
 * Carries props from `ListboxComponent` down to the scroll container that
 * react-window instantiates. This is needed because react-window creates that
 * element itself, so its props cannot be passed directly.
 */
// eslint-disable-next-line @typescript-eslint/naming-convention
const OuterElementContext = createContext({});

/**
 * Outer (scrolling) element type for `VariableSizeList`.
 *
 * Merges the props react-window hands in with the props supplied through
 * `OuterElementContext`. Context props are applied last, so they win on any
 * key collision. The forwarded ref is attached to the rendered `div`.
 */
const OuterElementType = forwardRef<HTMLDivElement>(function OuterElementType(
  ownProps,
  forwardedRef
) {
  const contextProps = useContext(OuterElementContext);
  const mergedProps = { ...ownProps, ...contextProps };
  return <div ref={forwardedRef} {...mergedProps} />;
});

/**
 * Virtualized listbox, presumably used as the `ListboxComponent` of a MUI
 * `Autocomplete`.
 *
 * Flattens `children` with `Children.toArray` and renders them through a
 * react-window `VariableSizeList`, so only visible rows (plus an overscan of 5)
 * are mounted.
 * - `ListSubheader` children are 48px tall. Other rows are 36px at the `sm`
 *   breakpoint and up, and 48px below it.
 * - The list shows at most 8 rows' worth of height.
 * - All props except `children` are forwarded to the list's scroll container
 *   through `OuterElementContext`.
 * - On mount, and whenever the position of the first child with a truthy
 *   `aria-selected` prop changes, that child is scrolled to the center.
 */
export const ListboxComponent = forwardRef<HTMLDivElement>(
  function ListboxComponent(props, ref) {
    const { children, ...other } = props;
    const itemData = Children.toArray(children);
    const gridRef = useRef<VariableSizeList>(null);
    const theme = useTheme();
    const smUp = useMediaQuery(theme.breakpoints.up('sm'), { noSsr: true });
    const itemCount = itemData.length;
    // Option rows are denser on `sm`-and-up screens.
    const itemSize = smUp ? 36 : 48;

    // Group headers have a fixed height; every other row uses `itemSize`.
    const getChildSize = (child: ReactNode) => {
      if (isValidElement(child) && child.type === ListSubheader) {
        return 48;
      }
      return itemSize;
    };

    const childSizes = itemData.map(getChildSize);
    // Changes whenever any row height changes: item count, header/option mix,
    // or a breakpoint flip of `itemSize`. It drives the size-cache reset below.
    const sizeSignature = childSizes.join(',');

    // NOTE: `itemData` is a fresh array on every render, so this memo
    // effectively recomputes each render. The result is a primitive, so the
    // scroll effect below still only fires when the index actually changes.
    const selectedItemIndex = useMemo(
      () =>
        itemData.findIndex(
          (child: ReactNode) =>
            isValidElement(child) && child?.props?.['aria-selected']
        ),
      [itemData]
    );

    useEffect(() => {
      // `findIndex` yields -1 when nothing is selected, and 0 is a valid
      // selection. A plain truthiness check would get both cases wrong.
      if (selectedItemIndex >= 0) {
        gridRef.current?.scrollToItem(selectedItemIndex, 'center');
      }
    }, [selectedItemIndex]);

    useEffect(() => {
      // VariableSizeList caches item sizes/offsets. Drop the cache (and force
      // a re-render) whenever any row height may have changed, not just the count.
      gridRef.current?.resetAfterIndex(0, true);
    }, [sizeSignature]);

    // Cap the viewport at 8 rows of `itemSize`; shorter lists fit their rows.
    const contentHeight =
      itemCount > 8
        ? 8 * itemSize
        : childSizes.reduce((total, size) => total + size, 0);

    return (
      <div ref={ref}>
        <OuterElementContext.Provider value={other}>
          <VariableSizeList
            itemData={itemData}
            height={contentHeight + 2 * listboxPadding}
            width="100%"
            ref={gridRef}
            outerElementType={OuterElementType}
            innerElementType="ul"
            itemSize={(index) => childSizes[index] ?? itemSize}
            overscanCount={5}
            itemCount={itemCount}
          >
            {renderRow}
          </VariableSizeList>
        </OuterElementContext.Provider>
      </div>
    );
  }
);

/**
 * `renderGroup` callback for an Autocomplete.
 *
 * Returns a two-element array: a `ListSubheader` (rendered as a `div`, keyed by
 * `params.key`) showing the group label, followed by the group's children.
 * The header and options are siblings rather than nested, presumably so that a
 * listbox that flattens its children (as `Children.toArray` does) sees each one
 * as its own row.
 */
export const renderGroup = (params: AutocompleteRenderGroupParams) => {
  const { key, group, children } = params;
  const header = (
    <ListSubheader key={key} component="div">
      {group}
    </ListSubheader>
  );
  return [header, children];
};
