import { curry } from 'ramda'
import { AggregateCoord } from '../../../benchmark/model/benchmark-chart'
import { BenchmarkChartView } from '../../../benchmark/model/benchmark-chart-view'
import { BenchmarkPeerSet } from '../../../benchmark/model/benchmark.model'
import { GraphingLineDatum } from '@graphing/models/graphing.model'
import { calculateLinearLeastSquaresCoords } from '@graphing/utils/linear-least-squares'
import { toTruthMap } from '@shared/util/object'
import { rejectNil } from '@shared/util/operators'
import { makeDataAggregator } from './aggregate-data'

/**
 * Builds lookup tables used to match chart data points to peer sets.
 *
 * @param data - Chart coordinates to index.
 * @param peerSets - Peer sets to index; treated as empty when omitted.
 * @returns
 *   - `peerIDsByPeerSetID`: peer set ID -> IDs of the peers in that set.
 *   - `dataIndexByID`: entity ID (`companyID`, else `peerSetID`, else
 *     `'__na'`) -> index in `data`; when an ID repeats, the last index wins.
 *   - `xMin` / `xMax`: extent of `x` over `data` (`Infinity` / `-Infinity`
 *     when `data` is empty).
 */
function getDataPeerSetInfo(
  data: AggregateCoord[],
  peerSets?: BenchmarkPeerSet[]
) {
  const peerIDsByPeerSetID: Record<string | number, (string | number)[]> = {}
  for (const peerSet of peerSets ?? []) {
    peerIDsByPeerSetID[peerSet.id] = peerSet.peers.map(peer => peer.id)
  }

  const dataIndexByID: Record<string | number, number> = {}
  let xMin = Infinity
  let xMax = -Infinity
  data.forEach((coord, index) => {
    // Plain comparisons (not Math.min/max) so NaN x values are ignored
    if (coord.x < xMin) {
      xMin = coord.x
    }
    if (coord.x > xMax) {
      xMax = coord.x
    }
    const entityKey = coord.companyID ?? coord.peerSetID ?? '__na'
    dataIndexByID[entityKey] = index
  })

  return { peerIDsByPeerSetID, dataIndexByID, xMin, xMax }
}

/** Options for {@link calculateAggregateRegressionLines}. */
export interface CalculateAggregateRegressionLinesOptions {
  /** When true, the overall regression line across all data is not computed. */
  skipRegression?: boolean
  /**
   * Company IDs removed from the data before any regression line is fitted.
   * When provided, entries without a `companyID` are removed as well.
   */
  rejectEntityIDs?: (string | number)[]
  /** Peer sets that each get their own regression line over their peers. */
  peerSets?: BenchmarkPeerSet[]
  /** Also fit regression lines for composite peer sets (skipped otherwise). */
  includeCompositeRegressions?: boolean
  /** Skip a peer set's regression line when only one of its peers has data. */
  hideSinglePeerRegressions?: boolean
}

/**
 * Computes the regression lines for an aggregate benchmark chart.
 *
 * Lines are returned in this order:
 * 1. one overall least-squares line over the data (unless `skipRegression`);
 * 2. one least-squares line per peer set in `options.peerSets`, fitted only to
 *    the data points of that peer set's peers.
 *
 * Peer sets are skipped when they are composite (unless
 * `includeCompositeRegressions`), have no peers, have no peers with data, or
 * have exactly one peer with data while `hideSinglePeerRegressions` is set.
 *
 * @param data - Chart coordinates to fit.
 * @param options - See {@link CalculateAggregateRegressionLinesOptions}.
 * @returns One array of coordinates per regression line.
 */
export function calculateAggregateRegressionLines<T extends AggregateCoord>(
  data: T[],
  options: CalculateAggregateRegressionLinesOptions = {}
): AggregateCoord[][] {
  const {
    rejectEntityIDs,
    skipRegression,
    peerSets,
    includeCompositeRegressions,
    hideSinglePeerRegressions,
  } = options
  const regressionLines: AggregateCoord[][] = []

  // Remove the target company (and any other rejected entities) from the
  // data behind every regression line. Entries without a companyID are also
  // dropped whenever `rejectEntityIDs` is provided.
  let series = data
  if (rejectEntityIDs) {
    const rejectEntityIDsMap = toTruthMap(rejectEntityIDs)
    series = data.filter(
      s => s.companyID != null && !rejectEntityIDsMap[s.companyID]
    )
  }

  if (!skipRegression) {
    // Overall regression line using linear least squares
    const coords = calculateLinearLeastSquaresCoords(series, {
      xAccessor: d => d.x,
      yAccessor: d => d.y,
    })
    const seriesLine = coords.map(c => ({
      ...c,
      companyID: 'regression',
      companyName: 'Regression',
    }))
    regressionLines.push(seriesLine)
  }

  if (!peerSets || peerSets.length === 0) {
    return regressionLines
  }

  // Map peers to their index in `series`; xMin/xMax (the x-extent of the
  // whole series) are passed to the fit of every peer-set line.
  const { dataIndexByID, peerIDsByPeerSetID, xMin, xMax } =
    getDataPeerSetInfo(series, peerSets)
  const companyID = 'regressionPeerSets'
  const companyName = 'Peer Group Regressions'

  for (const ps of peerSets) {
    // Composite peer sets are skipped unless explicitly requested
    if (ps.isComposite && !includeCompositeRegressions) {
      continue
    }
    const peerIDs = peerIDsByPeerSetID[ps.id]
    if (!peerIDs?.length) {
      continue
    }

    // Map each peer to its data in the series, dropping peers without data
    const dataPerPeer = rejectNil(
      peerIDs.map(id => series[dataIndexByID[id]])
    )

    // No peer has data: there is nothing to fit, so emit no line instead of
    // a regression over an empty set.
    if (dataPerPeer.length === 0) {
      continue
    }
    if (hideSinglePeerRegressions && dataPerPeer.length === 1) {
      continue
    }

    const peerSetCoords = calculateLinearLeastSquaresCoords(dataPerPeer, {
      xAccessor: d => d?.x ?? 0,
      yAccessor: d => d?.y ?? 0,
      xMin,
      xMax,
    })
    regressionLines.push(
      peerSetCoords.map(
        (c): AggregateCoord => ({
          ...c,
          companyID,
          companyName,
          peerSetID: ps.id,
          peerSetName: ps.name,
        })
      )
    )
  }

  return regressionLines
}

/**
 * Builds the horizontal reference lines for an aggregate benchmark chart: the
 * median `y` of the whole series, followed by one median `y` line per peer set
 * in `view.peerSets`, computed over that peer set's peers only.
 *
 * Curried with ramda, so it may be called as
 * `calculateAggregateYLines(view)(series)` or `calculateAggregateYLines(view, series)`.
 * Peer sets with no peers, or whose peers have no data in `series`, get no line.
 *
 * @param view - Chart view; only `peerSets` is read.
 * @param series - Chart coordinates to aggregate.
 * @returns The overall median line followed by one line per peer set with data.
 */
export const calculateAggregateYLines = curry(
  (view: BenchmarkChartView, series: AggregateCoord[]): GraphingLineDatum[] => {
    const yAccessor = (d: AggregateCoord) => d.y
    const aggregate = makeDataAggregator('median', yAccessor)
    // Read peer sets once so the lookup tables and the loop below agree
    const peerSets = view?.peerSets ?? []

    const { dataIndexByID, peerIDsByPeerSetID } = getDataPeerSetInfo(
      series,
      peerSets
    )

    const yLines: GraphingLineDatum[] = [
      { id: 'median', name: 'Median', value: aggregate(series) },
    ]

    for (const ps of peerSets) {
      // Get peers within this peer set
      const peerIDs = peerIDsByPeerSetID[ps.id]
      if (!peerIDs?.length) {
        continue
      }

      // Map each peer to its data in the series, dropping peers without data
      const dataPerPeer = rejectNil(
        peerIDs.map(id => series[dataIndexByID[id]])
      )
      // Nothing to take a median of: emit no line for this peer set
      if (dataPerPeer.length === 0) {
        continue
      }

      // Median y value within this peer set's peers
      yLines.push({ id: ps.id, name: ps.name, value: aggregate(dataPerPeer) })
    }

    return yLines
  }
)
