import groupBy from 'lodash/fp/groupBy';
import isNil from 'lodash/fp/isNil';
import { type BigNumber, bignumber } from 'mathjs';
import type { ReactElement } from 'react';
import { NO_DATA_SUNBURST_ITEM } from 'components/technical/charts/SunburstChart/NoDataSunburstItem.ts';
import StaticAssetGroupSunburst, {
  type StaticAssetGroupSunburstProps,
} from 'components/technical/charts/SunburstChart/StaticAssetGroupSunburst';
import type { SunburstChartData } from 'components/technical/charts/SunburstChart/SunburstChart.props';
import {
  calculateChartInput,
  type ElementChild,
  type ElementTree,
  isChildElement,
} from 'components/technical/charts/SunburstChart/SunburstChart.utils';
import type { Aggregation, AggregationValue } from './PositionAggregationsService.ts';
import bigNumMath from '../../../bigNumMath';
import { getLightChartColor, getRegularChartColor } from '../../../theme/colors';
import { formatCash } from '../../formatter.utils';
import { mapValues, uniqBy } from 'lodash/fp';
import { useTheme } from '@mui/joy';
import { useFinalColorScheme } from '../../../useFinalColorScheme.ts';
import sortBy from 'lodash/fp/sortBy';
import { useGenerateKeyOnValueChanged } from '../../UseGenerateKeyOnValueChanged.tsx';

/**
 * A sunburst chart element (see `ElementChild`) that also carries the BigNumber value it represents.
 * The value is read back by the chart's `valueProvider` when the chart input is built.
 */
type ChartChild = ElementChild & {
  value: BigNumber;
};

/**
 * Picks the color of a sunburst element from its depth in the tree.
 *
 * Depths cycle with period 3. The depth whose remainder modulo 3 is 1 uses the light palette; every other depth
 * uses the regular palette.
 *
 * @param colorScheme - active color scheme, forwarded to the palette helpers
 * @param level - depth of the element (0 for the first ring below the root)
 * @param index - palette index to pick from
 * @returns the color string produced by the palette helper
 */
const calculateColorForLevel = (colorScheme: 'dark' | 'light', level: number, index: number): string =>
  level % 3 === 1 ? getLightChartColor(colorScheme, index) : getRegularChartColor(colorScheme, index);

/** Selector option that charts the balance of each position. */
const balanceOption = { label: 'Balance', category: 'balance' };

/** Selector option that charts the exposure of each position. */
const exposureOption = { label: 'Exposure', category: 'exposure' };

/** Options handed to the sunburst's balance/exposure selector, balance first (the first entry is the default). */
const balanceExposureDimensionDefaultOptions: StaticAssetGroupSunburstProps['options'] = [balanceOption, exposureOption];

/** Minimal shape every position must have to be charted: a reference to the asset it holds. */
export type PortfolioSnapshotSunburstAsset = {
  asset: {
    id: string;
    symbol: string;
  };
};

/**
 * Returns a new array with `items` ordered by `valueCoeff`, largest first; the input array is left untouched.
 * Coefficients are compared after conversion with `toNumber()`.
 */
const sortByValueCoeffDesc = <T extends Omit<ChartChild, 'color'>>(
  items: (ElementTree<T> & { valueCoeff: BigNumber })[]
): (ElementTree<T> & { valueCoeff: BigNumber })[] => {
  const byCoeffDescending = (left: { valueCoeff: BigNumber }, right: { valueCoeff: BigNumber }): number =>
    right.valueCoeff.toNumber() - left.valueCoeff.toNumber();

  return items.toSorted(byCoeffDescending);
};

/**
 * Recursively groups positions into the sunburst tree.
 *
 * The first id in `aggregationOrder` selects the aggregation used to split `positions` into groups, keyed by the
 * `id` of the aggregation value. Each group is then split further by the remaining ids. When no ids are left, the
 * positions are collapsed into a single value.
 *
 * @param positions - positions to group
 * @param aggregations - available aggregations keyed by id
 * @param aggregationOrder - ids of the aggregations to apply, outermost grouping first. Ids without a matching
 *   aggregation are skipped instead of crashing.
 * @param calculateValue - value of a single position; `undefined` is treated as 0
 * @returns the child elements sorted by `valueCoeff` (descending), the signed sum of the position values
 *   (`value`) and the sum of their absolute values (`valueCoeff`)
 */
const calculateGroups = <TPosition,>(
  positions: Array<TPosition & PortfolioSnapshotSunburstAsset>,
  aggregations: Record<string, Aggregation<TPosition>>,
  aggregationOrder: string[],
  calculateValue: (position: TPosition) => BigNumber | undefined
): { elements: ElementTree<ChartChild>[]; value: BigNumber; valueCoeff: BigNumber } => {
  if (positions.length === 0) {
    return { elements: [], value: bignumber(0), valueCoeff: bignumber(0) };
  }

  if (aggregationOrder.length === 0) {
    // leaf: evaluate every position once and derive both the signed and the absolute sum from the same values
    const values = positions.map((pos) => calculateValue(pos) ?? bignumber(0));
    return {
      elements: [],
      value: bigNumMath.sum(values),
      valueCoeff: bigNumMath.sum(values.map((val) => val.abs())),
    };
  }

  const groupingLayerId = aggregationOrder[0];
  const remainingOrder = aggregationOrder.slice(1);
  const groupingLayer: Aggregation<TPosition> | undefined = aggregations[groupingLayerId];
  if (isNil(groupingLayer)) {
    // unknown aggregation id (e.g. a stale order): skip this ring instead of crashing on `undefined.calculateValue`
    return calculateGroups(positions, aggregations, remainingOrder, calculateValue);
  }

  const elements = Object.entries(groupBy((pos) => groupingLayer.calculateValue(pos).id, positions)).map(
    ([_categoryValueId, categoryPositions]) => {
      const groups = calculateGroups(categoryPositions, aggregations, remainingOrder, calculateValue);

      return {
        // all positions in a group share the same aggregation value id, so the first one supplies the label
        label: groupingLayer.calculateValue(categoryPositions[0]).label,
        elements: groups.elements,
        value: groups.value,
        valueCoeff: groups.valueCoeff,
      };
    }
  );

  return {
    elements: sortByValueCoeffDesc(elements),
    value: bigNumMath.sum(elements.map((group) => group.value)),
    valueCoeff: bigNumMath.sum(elements.map((group) => group.valueCoeff)),
  };
};

/**
 * Returns a copy of the element tree in which every element carries a `color`.
 *
 * The element at sibling position `position` uses palette index `parentIndex + position` at depth `level`. Its
 * children are colored at depth `level + 1`, with palette indices counted from that same index. Elements are
 * shallow-copied; the input tree is not mutated.
 */
const assignColors = (
  elements: ElementTree<ChartChild>[],
  colorScheme: 'dark' | 'light',
  level: number,
  parentIndex: number
): ElementTree<ChartChild>[] =>
  elements.map((element, position) => {
    const paletteIndex = parentIndex + position;
    const color = calculateColorForLevel(colorScheme, level, paletteIndex);

    return isChildElement(element)
      ? { ...element, color }
      : { ...element, color, elements: assignColors(element.elements, colorScheme, level + 1, paletteIndex) };
  });

/** Props of {@link PortfolioSnapshotSunburst}. */
type PortfolioSnapshotSunburstProps<TPosition> = {
  /** Positions to chart; each one must reference its asset. An empty list renders the no-data item. */
  positions: Array<PortfolioSnapshotSunburstAsset & TPosition>;
  /** Available aggregations keyed by id; `undefined` renders the no-data item. */
  aggregationsByCategory: Record<string, Aggregation<TPosition>> | undefined;
  /** Aggregation ids to group by, innermost ring first; `undefined` renders the no-data item. */
  aggregationOrder: string[] | undefined;

  /** Value of a position when the balance option is selected. */
  calculateBalance: (position: TPosition) => BigNumber | undefined;
  /** Value of a position when the exposure option is selected; when omitted, the category selection is hidden. */
  calculateExposure?: (position: TPosition) => BigNumber | undefined;

  /** Forwarded unchanged to `StaticAssetGroupSunburst`. */
  loaded?: boolean;
  /** Forwarded unchanged to `StaticAssetGroupSunburst`. */
  Fallback?: () => ReactElement;
};

/**
 * Sunburst chart of portfolio positions grouped by a configurable list of aggregations.
 *
 * The user can switch between charting balances and exposures. The switch is hidden when `calculateExposure` is
 * not provided. Positions whose absolute value is at most 0.0001% of the absolute total are not charted, because
 * tiny segments crash plotly. The root still reports the unfiltered total. When there is nothing to chart,
 * `NO_DATA_SUNBURST_ITEM` is returned instead.
 */
const PortfolioSnapshotSunburst = <TPosition,>({
  aggregationsByCategory,
  positions,
  loaded,
  Fallback,
  calculateExposure,
  aggregationOrder,
  calculateBalance,
}: PortfolioSnapshotSunburstProps<TPosition>): ReactElement => {
  const colorScheme = useFinalColorScheme();
  const theme = useTheme();

  // distinct values of every aggregation, sorted by label
  // NOTE(review): only consumed by the no-data guard below; lodash `mapValues` returns an object even for an
  // undefined input, so that part of the guard looks like it can never fire - confirm and consider removing
  const groupingLayerValues = mapValues((group) => {
    const groupValues = positions.map((pos) => group.calculateValue(pos));
    return sortBy(
      (item: AggregationValue) => item.label,
      uniqBy((item) => item.id, groupValues)
    );
  }, aggregationsByCategory);

  // recreate the chart when the aggregation order changes - parent ids in the sunburst differ between orders
  const key = useGenerateKeyOnValueChanged(aggregationOrder);
  return (
    <StaticAssetGroupSunburst
      options={balanceExposureDimensionDefaultOptions}
      loaded={loaded}
      Fallback={Fallback}
      key={key}
      hideCategorySelection={isNil(calculateExposure)}
      calculateChartData={(balanceOrExposure): Omit<SunburstChartData, 'hoverinfo' | 'textinfo'> => {
        if (positions.length === 0 || !aggregationsByCategory || !groupingLayerValues || !aggregationOrder) {
          return NO_DATA_SUNBURST_ITEM;
        }

        // single source of truth for the selected dimension, shared by the filter and the grouping
        const calculateValue = (pos: TPosition): BigNumber | undefined =>
          balanceOrExposure === balanceOption.category ? calculateBalance(pos) : calculateExposure?.(pos);

        const totalValue = bigNumMath.sum(positions.map((pos) => calculateValue(pos)?.abs() ?? bignumber(0)));
        // we need fairly good precision because we want to show an accurate total value, but we cannot chart
        // everything because tiny segments crash plotly
        const minValue = 0.000001 * totalValue.toNumber(); // show min 0.0001%
        const filteredPos = positions.filter((pos) => {
          const value = calculateValue(pos)?.abs();
          return !isNil(value) && value.greaterThan(minValue);
        });

        if (filteredPos.length === 0) {
          // nothing left to chart (e.g. every value is zero or undefined for the selected dimension)
          return NO_DATA_SUNBURST_ITEM;
        }

        const root: ElementTree<ChartChild> = {
          label: 'Total',
          color: theme.palette.neutral['100'],
          elements: assignColors(
            calculateGroups(filteredPos, aggregationsByCategory, aggregationOrder, calculateValue).elements,
            colorScheme,
            0,
            0
          ),
        };

        return calculateChartInput({
          root,
          valueProvider: (val) => bignumber(val.value),
          textProvider: (value) => formatCash(value),
          rootValue: totalValue,
        });
      }}
    />
  );
};

export default PortfolioSnapshotSunburst;
