import Decimal from "decimal.js";
import { cloneDeep, memoize } from "lodash";

import { DateOnly } from "../date_utils";
import {
  DatePairPoint,
  DatePoint,
  SBSimulationHolding,
  SimulatorError,
  SimulatorErrors,
} from "./model";
import * as resampleUtils from "./resample";
import { scaleTimeSeries } from "./resample";
import { approximateResolution, computeDeltas, deltaSquash } from "./utils";

/**
 * Synchronous view over the price time series of a set of symbols. It answers
 * three questions: what a symbol's price is on a given date, which date should
 * be simulated next, and which date range the available data spans.
 *
 * DEPRECATED: use AsyncSymbolPriceTimeSeries instead.
 */
export type SymbolPriceTimeSeries = {
  /** Start date of the available data. */
  dataStart: DateOnly;
  /** End date of the available data. */
  dataEnd: DateOnly;
  /** Price of `symbol` at `date`. */
  lookup: (args: { date: DateOnly; symbol: string }) => Decimal;
  /** Next valid date to simulate after `date`, or `undefined` if there is none. */
  nextDate: (date: DateOnly) => DateOnly | undefined;
};

/**
 * Promise-based counterpart of SymbolPriceTimeSeries. Prefer this one.
 */
export type AsyncSymbolPriceTimeSeries = {
  /** Start date of the available data. */
  dataStart: DateOnly;
  /** End date of the available data. */
  dataEnd: DateOnly;
  /**
   * Resolves to the price of `symbol` at `date`.
   * NOTE(review): presumably resolves to `undefined` when no price is
   * available for that date/symbol — confirm against implementations.
   */
  lookup: (args: {
    date: DateOnly;
    symbol: string;
  }) => Promise<Decimal | undefined>;
  /** Resolves to the next valid date to simulate after `date`, or `undefined`. */
  nextDate: (date: DateOnly) => Promise<DateOnly | undefined>;
  /**
   * NOTE(review): presumably resolves to whether any data exists for `date` —
   * confirm against implementations.
   */
  hasData: (date: DateOnly) => Promise<boolean>;
};

/**
 * Builds {@link SymbolPriceTimeSeries} instances by projecting the period
 * returns ("deltas") of a seed index onto each holding.
 */
export class SymbolPriceTimeSeriesFactory {
  /**
   * Creates a price time series for every holding from seed index data.
   *
   * The seed deltas for the simulation window are either taken from the
   * historical seed prices or, when `resample` is given, resampled to hit the
   * target return. Each holding's prices are then derived by squashing the
   * deltas with its volatility multiplier and rescaling the total return with
   * its average return multiplier.
   *
   * @param returnMultiplier      Beta values for each supported symbol, used for price returns
   * @param volatilityMultiplier  Beta values for each supported symbol, used for price volatility
   * @param seedPrices            Prices for the seed stock data (i.e. SPX500)
   * @param holdings              Initial holdings to use for the simulation; each holding's
   *                              `price` is the starting price of its series
   * @param simulationStart       The start date of the simulation; moved to this-or-next business
   *                              day. When not resampling it must not precede the first seed date.
   * @param simulationEnd         The requested end date of the simulation. When not resampling it
   *                              must not exceed the last seed date. The returned `dataEnd` is the
   *                              end date of the last delta, which may differ from this value.
   * @param resample              If provided, the seed data is resampled to match `targetReturn`
   *                              (an explicit `0` is a valid target and still resamples).
   *                              `targetReturn` is the average yearly return over the period:
   *                              -- 5% is represented as 0.05
   *                              -- -10% is represented as -0.1
   * @returns A time series whose `lookup` returns the price at the last data
   *          point on or before the requested date (the first data point if
   *          the date precedes all data, `0` for unknown symbols), and whose
   *          `nextDate` returns `undefined` once the simulation end is reached.
   * @throws SimulatorError (InvalidInput) if `resample.targetReturn <= -1`, if
   *         no deltas exist for the simulation window, or if the seed data is
   *         empty when not resampling.
   * @throws Error if, when not resampling, the simulation window is outside
   *         the seed data range.
   */
  public static fromSeedData(
    returnMultiplier: (date: DateOnly, symbol: string) => number,
    volatilityMultiplier: (date: DateOnly, symbol: string) => number,
    seedPrices: DatePoint[],
    holdings: SBSimulationHolding[],
    simulationStart: DateOnly,
    simulationEnd: DateOnly,
    resample?: {
      targetReturn: number; // TODO: update to Decimal for consistency
    }
  ): SymbolPriceTimeSeries {
    if ((resample?.targetReturn ?? 1) <= -1) {
      throw new SimulatorError(
        SimulatorErrors.InvalidInput,
        "resample.targetReturn must be greater than -1"
      );
    }

    const holdingsClone = cloneDeep(holdings);

    // start on business days
    simulationStart = simulationStart.thisOrNextBusinessDay();

    // Generate or calculate the deltas for the seed data, which will be used
    // to generate the future prices for each stock.
    const duration = simulationEnd.diff(simulationStart, "years");
    const finalTarget = Math.pow(1 + (resample?.targetReturn ?? 0), duration);
    const simulationDeltas = this.getSeedDeltas(
      seedPrices,
      simulationStart,
      simulationEnd,
      // Compare against undefined rather than truthiness: a target return of
      // 0 (flat market) is a legitimate request to resample.
      resample?.targetReturn !== undefined ? finalTarget : undefined
    );

    // Everything below needs at least one delta; fail with a descriptive
    // error instead of a TypeError on `simulationDeltas[0]`.
    const firstDelta = simulationDeltas[0];
    if (firstDelta === undefined) {
      throw new SimulatorError(
        SimulatorErrors.InvalidInput,
        "No price deltas are available between simulationStart and simulationEnd"
      );
    }

    const firstDate = firstDelta.startDate;

    // For each stock symbol compute the future prices for the same
    // dates as the resampled deltas
    const stockData: Map<string, DatePoint[]> = new Map();

    // Start with the initial prices
    holdingsClone.forEach((holding) => {
      stockData.set(holding.symbol, [
        { date: firstDate, value: holding.price.toNumber() },
      ]);
    });

    // Compute the future prices for each stock
    const seedReturn = simulationDeltas.reduce((a, b) => a * (b.value + 1), 1);
    const seedGrowthRate = Math.pow(seedReturn, 1 / duration) - 1;
    for (const holding of holdingsClone) {
      // multiply the deltas by the beta to simulate volatility
      // minor tweak to betas, check the comments of deltaSquash for more info
      const scaledDeltas = simulationDeltas.map((d) => ({
        ...d,
        value: deltaSquash(
          d.value,
          volatilityMultiplier(d.startDate, holding.symbol)
        ),
      }));

      // We want to scale the yearly growth rate for a stock, but the multiplier changes.
      // So we take the average multiplier over the time period and use that to scale the
      // average return.
      const multiplierSum = simulationDeltas.reduce(
        (sum, d) => sum + returnMultiplier(d.startDate, holding.symbol),
        0
      );
      const multiplier = multiplierSum / simulationDeltas.length;

      // Scale the total returns for this stock as per modeling assumptions
      const stockTotalReturn = Math.pow(
        seedGrowthRate * multiplier + 1,
        duration
      );
      const finalScale = scaleTimeSeries(scaledDeltas, stockTotalReturn);

      // Step through each date in the delta time series and
      // compute the new price for each of the stocks.
      const data = stockData.get(holding.symbol);
      const startPrice = holding.price.toNumber();
      let cumDelta = 1;
      finalScale.forEach(({ endDate, value: deltaValue }) => {
        cumDelta *= deltaValue + 1;
        data?.push({ date: endDate, value: cumDelta * startPrice });
      });
    }

    // TODO: A more accurate implementation would interpolate between the dates.
    const memoizeLookup: (args: { date: DateOnly; symbol: string }) => Decimal =
      memoize(
        ({ date, symbol }) => {
          const data = stockData.get(symbol);
          // Unknown symbols are priced at zero.
          if (!data) return new Decimal(0);

          // Get the price at the last data point that is on or before the
          // date passed in, falling back to the initial price.
          const closestDatePoint =
            SymbolPriceTimeSeriesFactory.lastMatching(
              data,
              (point) => point.date.valueOf() <= date.valueOf()
            ) ?? data[0];

          return new Decimal(closestDatePoint.value);
        },
        // Cache key: one entry per (date, symbol) pair.
        ({ date, symbol }) => `${date.toString()}-${symbol}`
      );

    // exact simulation end date has to fall on the correct resolution
    simulationEnd = simulationDeltas.at(-1)?.endDate ?? simulationEnd;

    // The nextDate should work for any date, not just the dates in the simulationPrices
    const memoizeNextDate: (date: DateOnly) => DateOnly | undefined = memoize(
      (previousDate: DateOnly) => {
        // check if simulation is over
        if (previousDate >= simulationEnd) {
          return undefined;
        }
        // If the previous date is before the simulation start, return the simulation start `endDate`
        if (previousDate < firstDelta.startDate) {
          return firstDelta.endDate;
        }

        // find the last delta starting on or before the previous date, then
        // grab the `endDate` of that delta
        return SymbolPriceTimeSeriesFactory.lastMatching(
          simulationDeltas,
          (delta) => delta.startDate.valueOf() <= previousDate.valueOf()
        )?.endDate;
      },
      (previousDate: DateOnly) => previousDate.toString()
    );

    return {
      lookup: memoizeLookup,
      dataStart: simulationStart,
      dataEnd: simulationEnd,
      nextDate: memoizeNextDate,
    };
  }

  /**
   * A primary component for this model is the vector of changes for the seed index
   * for the duration of the simulation. So, depending on whether the simulation is
   * sampled (e.g. for a future simulation) or a historical simulation, we generate
   * the deltas and return them.
   *
   * Passing a `targetReturn` will cause the deltas to be sampled, and will match the
   * desired return. Without it, the historical seed deltas that lie entirely
   * within [simulationStart, simulationEnd] are returned.
   *
   * @param seedPrices       Seed prices; sorted by date on a copy, the input is not mutated
   * @param simulationStart  First date of the simulation window
   * @param simulationEnd    Last date of the simulation window
   * @param targetReturn     Total return handed to the resampler; `undefined` selects
   *                         the historical branch
   * @returns The deltas for the simulation, ordered by start date in the sampled branch
   * @throws SimulatorError (InvalidInput) if the seed data is empty when not resampling
   * @throws Error if the window is outside the seed data range when not resampling
   */
  private static getSeedDeltas(
    seedPrices: DatePoint[],
    simulationStart: DateOnly,
    simulationEnd: DateOnly,
    targetReturn?: number
  ): DatePairPoint[] {
    /*
     * Data from which we will calculate returns for stocks in the portfolio
     * using delta * beta method.
     */
    seedPrices = [...seedPrices].sort((a, b) =>
      DateOnly.compare(a.date, b.date)
    );
    const seedDeltas = computeDeltas(seedPrices);

    let simulationDeltas: DatePairPoint[];

    // If a target return is specified, we will sample a new set of deltas, otherwise
    // we will use the deltas from the (historical) seed prices.
    if (targetReturn !== undefined) {
      // First construct the dates for the resampled deltas
      const seedResolution = this.getApproximateResolution(seedPrices);
      const dates = [simulationStart];
      let nextDate = simulationStart;
      // NOTE(review): each date is snapped to a business day independently,
      // so for resolutions shorter than a weekend consecutive entries may
      // presumably snap to the same day — confirm whether zero-length
      // intervals are acceptable.
      do {
        nextDate = nextDate.add(1, seedResolution);
        dates.push(nextDate.thisOrNextBusinessDay()); // use business days
      } while (nextDate < simulationEnd);

      const sampledDeltas = resampleUtils.resampleDeltas(
        seedDeltas.map((d) => d.value),
        dates.length - 1, // we don't need a delta for the last date
        targetReturn
      );

      simulationDeltas = sampledDeltas.map((value, i) => ({
        startDate: dates[i],
        endDate: dates[i + 1],
        value,
      }));
      simulationDeltas.sort((a, b) =>
        DateOnly.compare(a.startDate, b.startDate)
      );
    } else {
      /**
       * If we are not resampling, we need only sub-sample the seed deltas to the
       * required range.
       */
      const firstSeed = seedPrices[0];
      const lastSeed = seedPrices[seedPrices.length - 1];
      // Without seed data there is no range to validate against; report that
      // explicitly instead of failing on `seedPrices[0].date`.
      if (firstSeed === undefined || lastSeed === undefined) {
        throw new SimulatorError(
          SimulatorErrors.InvalidInput,
          "seedPrices must not be empty when not resampling"
        );
      }

      // check start and end dates are within the seed data
      if (simulationStart < firstSeed.date || simulationEnd > lastSeed.date) {
        throw new Error(
          `Simulation start and end dates must be within the seed data range when not resampling: ${
            firstSeed.date
          } - ${lastSeed.date}`
        );
      }
      simulationDeltas = seedDeltas.filter(
        (d) => d.startDate >= simulationStart && d.endDate <= simulationEnd
      );
    }

    return simulationDeltas;
  }

  /**
   * Returns the approximate resolution of the seed prices, falling back to
   * "month" when `approximateResolution` yields no result.
   *
   * TODO: for some reason, the data passed in (from the historicalDeltas query) has started returning a single
   * 3 day duration
   *
   * See: https://frec.atlassian.net/browse/FREC-3233?atlOrigin=eyJpIjoiOTE1NjY0ZDNiMzZlNGU0M2IzNjM2NDI0N2MyYmJlMTQiLCJwIjoiamlyYS1zbGFjay1pbnQifQ
   */
  private static getApproximateResolution(
    seedPrices: DatePoint[]
  ): NonNullable<ReturnType<typeof approximateResolution>> {
    return approximateResolution(seedPrices) ?? "month";
  }

  /**
   * Returns the last element of `items` (in array order) that satisfies
   * `predicate`, or `undefined` when none does. Scans from the end so no
   * intermediate array is allocated.
   */
  private static lastMatching<T>(
    items: readonly T[],
    predicate: (item: T) => boolean
  ): T | undefined {
    for (let i = items.length - 1; i >= 0; i--) {
      const item = items[i];
      if (predicate(item)) return item;
    }
    return undefined;
  }
}
