import { curry } from 'ramda'
import { Coord } from './coord'

/**
 * Options for fitting a least-squares line to a list of arbitrary records.
 */
export interface LinearLeastSquaresOptions<T> {
  /** Extracts the x value from a record. */
  xAccessor: (d: T) => number
  /** Extracts the y value from a record. */
  yAccessor: (d: T) => number
  /** Lower x bound of the returned line; when omitted it is derived from the data (0 for empty data). */
  xMin?: number
  /** Upper x bound of the returned line; when omitted it is derived from the data (0 for empty data). */
  xMax?: number
}

/** Running totals accumulated over the data points during a least-squares fit. */
interface DataSums {
  /** Σ x */
  x: number
  /** Σ y */
  y: number
  /** Σ x² */
  xx: number
  /** Σ x·y */
  xy: number
}

/** A fitted line `y = x * slope + intersection` together with the x range it spans. */
interface LeastSquaresLine {
  /** Change in y per unit change in x. */
  slope: number
  /** The y value at x = 0 (the y-intercept). */
  intersection: number
  /** Left end of the x range the line covers. */
  xMin: number
  /** Right end of the x range the line covers. */
  xMax: number
}

/**
 * Curried point evaluator: given a line's slope and y-intercept, yields the
 * coordinate on that line at `x`. Supports partial application, e.g.
 * `calculateCoord(slope, intersection)(x)`.
 */
const calculateCoord = curry(
  (slope: number, intersection: number, x: number): Coord => {
    const y = slope * x + intersection
    return { x, y }
  }
)

/**
 * Fits a least-squares line to `data` and returns its two end points, the
 * first evaluated at the line's `xMin` and the second at its `xMax`.
 *
 * @param data - Records to fit.
 * @param options - Accessors for x and y, plus optional x bounds.
 * @returns `[start, end]` coordinates of the fitted line.
 */
export const calculateLinearLeastSquaresCoords = <T>(
  data: T[],
  options: LinearLeastSquaresOptions<T>
): readonly [Coord, Coord] => {
  const { slope, intersection, xMin, xMax } = calculateLinearLeastSquaresLine(
    data,
    options
  )

  const pointAt = calculateCoord(slope, intersection)
  const start = pointAt(xMin)
  const end = pointAt(xMax)

  return [start, end] as const
}

/**
 * Fits the straight line `y = slope * x + intersection` to `data` by ordinary
 * linear least squares and reports the x range the line should span.
 *
 * Bounds: `options.xMin` and `options.xMax` each override their bound
 * independently. A bound that is not supplied falls back to the smallest or
 * largest x in `data`, or to 0 when `data` is empty.
 *
 * Degenerate inputs never produce NaN or Infinity:
 * - empty `data` yields `slope: 0, intersection: 0`;
 * - when every x is identical (including a single point) the slope is
 *   undefined, so a horizontal line through the mean y is returned.
 *
 * @param data - Records to fit; not mutated.
 * @param options - Accessors for x and y, plus optional fixed x bounds.
 * @returns Slope, intersection (y at x = 0) and the x bounds of the line.
 */
export const calculateLinearLeastSquaresLine = <T>(
  data: readonly T[],
  options: LinearLeastSquaresOptions<T>
): LeastSquaresLine => {
  const { xAccessor, yAccessor } = options
  const n = data.length

  if (n === 0) {
    return {
      slope: 0,
      intersection: 0,
      xMin: options.xMin ?? 0,
      xMax: options.xMax ?? 0,
    }
  }

  // Accumulate sums relative to the first point instead of the origin. The
  // textbook form n·Σx² − (Σx)² subtracts two huge, nearly equal numbers when
  // x is large (e.g. timestamps) and loses most of its precision. Shifting
  // keeps the summed values small, and the shift is added back into the
  // means below, so the fitted line is mathematically unchanged.
  const shiftX = xAccessor(data[0])
  const shiftY = yAccessor(data[0])

  const sum: DataSums = { x: 0, y: 0, xx: 0, xy: 0 }
  let dataXMin = Infinity
  let dataXMax = -Infinity

  for (const d of data) {
    const x = xAccessor(d)
    const dx = x - shiftX
    const dy = yAccessor(d) - shiftY

    sum.x += dx
    sum.y += dy
    sum.xx += dx * dx
    sum.xy += dx * dy

    if (x < dataXMin) {
      dataXMin = x
    }
    if (x > dataXMax) {
      dataXMax = x
    }
  }

  // Each supplied bound wins on its own; a missing one comes from the data.
  const xMin = options.xMin ?? dataXMin
  const xMax = options.xMax ?? dataXMax

  const meanX = shiftX + sum.x / n
  const meanY = shiftY + sum.y / n

  const denominator = sum.xx * n - sum.x * sum.x
  // Exactly zero when every x is equal (all dx are 0), which includes the
  // single-point case: the slope is undefined there, and a horizontal line
  // through the mean y is the least-squares fit.
  if (denominator === 0) {
    return { slope: 0, intersection: meanY, xMin, xMax }
  }

  const slope = (sum.xy * n - sum.x * sum.y) / denominator
  const intersection = meanY - slope * meanX

  return { slope, intersection, xMin, xMax }
}
