import {
  GetJobsApiResponse,
  GetModelsByModelIdApiResponse,
} from "infra/api/generated/api";
import { apiSlice } from "infra/store/apiSlice";
import { API_TAGS } from "infra/store/constants";
import { AppRootState } from "infra/store/store";
import moment from "moment";

import { AnyAction, ThunkDispatch } from "@reduxjs/toolkit";

/**
 * Reconciles cached RTK Query model data with a job that has just finished.
 *
 * Only jobs whose status is "success", "canceled" or "failure", that were
 * updated within the last 5 minutes, and whose metadata carries a `modelId`
 * are considered; every other job is ignored.
 *
 * - Model upload jobs (`jobType === "upload"` whose `input.uploadType`
 *   includes "model"): if the model is not found in any cached `getModels`
 *   result, the `API_TAGS.model` tag is invalidated so those queries refetch.
 * - Training jobs (`"flux-model-training"` / `"model-training"`): the model's
 *   `status` is rewritten in every cached `getModels` and `getModelsByModelId`
 *   entry (success → "trained", canceled → "training-canceled",
 *   failure → "failed").
 *
 * @param job - Job entry as returned by the jobs API.
 * @param dispatch - Store dispatch, used to run RTK Query cache utilities.
 * @param getState - Store state getter, used to locate cached model queries.
 */
export function updateQueryDataByUpdatingModelStatus({
  job,
  dispatch,
  getState,
}: {
  job: GetJobsApiResponse["jobs"][number];
  dispatch: ThunkDispatch<any, any, AnyAction>;
  getState: () => AppRootState;
}) {
  // `modelId` is not declared on the generated metadata type, hence the cast.
  const modelId = (job.metadata as any)?.modelId;

  const isFinished = ["success", "canceled", "failure"].includes(job.status);
  // Jobs whose last update is older than 5 minutes are ignored.
  const isStale = moment(job.updatedAt).isBefore(
    moment().subtract(5, "minutes")
  );
  if (!isFinished || isStale || !modelId) {
    return;
  }

  const jobInput = job.metadata?.input as any;

  // `uploadType` may be absent even when `input` exists; guard both so a
  // malformed upload job is skipped instead of throwing.
  if (job.jobType === "upload" && jobInput?.uploadType?.includes("model")) {
    const state = getState();
    const cachedModelQueries = apiSlice.util.selectInvalidatedBy(state, [
      API_TAGS.model,
    ]);
    // Read (not patch) each cached `getModels` result; stops at the first
    // list that already contains the uploaded model.
    const isModelInList = cachedModelQueries.some(
      ({ endpointName, originalArgs }) =>
        endpointName === "getModels" &&
        Boolean(
          apiSlice.endpoints.getModels
            .select(originalArgs)(state)
            .data?.models?.some((model) => model.id === modelId)
        )
    );
    // Not in any cached list (or no list is cached): invalidate so the
    // model lists refetch.
    if (!isModelInList) {
      dispatch(apiSlice.util.invalidateTags([API_TAGS.model]));
    }
  } else if (
    job.jobType === "flux-model-training" ||
    job.jobType === "model-training"
  ) {
    // Exhaustive over job statuses so a new API status fails to compile
    // until it is mapped here.
    const jobStatusToModelStatus: {
      [key in GetJobsApiResponse["jobs"][number]["status"]]:
        | GetModelsByModelIdApiResponse["model"]["status"]
        | undefined;
    } = {
      success: "trained",
      canceled: "training-canceled",
      failure: "failed",
      "in-progress": undefined,
      queued: undefined,
      "warming-up": undefined,
    };

    const newModelStatus = jobStatusToModelStatus[job.status];
    // Narrows away `undefined`; non-terminal statuses were already rejected.
    if (!newModelStatus) return;

    const state = getState();
    const cachedModelQueries = apiSlice.util.selectInvalidatedBy(state, [
      API_TAGS.model,
    ]);
    for (const { endpointName, originalArgs } of cachedModelQueries) {
      if (endpointName === "getModels") {
        dispatch(
          apiSlice.util.updateQueryData(endpointName, originalArgs, (draft) => {
            const modelToUpdate = draft?.models?.find(
              (model) => model.id === modelId
            );
            if (modelToUpdate) {
              modelToUpdate.status = newModelStatus;
            }
          })
        );
      }
      if (endpointName === "getModelsByModelId") {
        dispatch(
          apiSlice.util.updateQueryData(endpointName, originalArgs, (draft) => {
            if (draft?.model && draft.model.id === modelId) {
              draft.model.status = newModelStatus;
            }
          })
        );
      }
    }
  }
}
