import isLCMScheduler from "domains/inference/utils/isLCMScheduler";
import {
  getIsModelFlux,
  getIsModelFluxPro,
  getIsModelFluxPro1_1,
  getIsModelFluxPro1_1Ultra,
  getIsModelFluxSchnell,
  getIsModelSdxl,
} from "domains/models/utils";
import {
  GetModelsByModelIdApiResponse,
  PostModelsInferencesByModelIdApiArg,
} from "infra/api/generated/api";

import { SCHEDULER_TYPE } from "./Schedulers";

/** The `parameters` object of the generated model-inference request body. */
type InferenceRequestParameters =
  PostModelsInferencesByModelIdApiArg["body"]["parameters"];

/** Every `parameters.type` value accepted by the model-inference endpoint. */
type InferenceType = InferenceRequestParameters["type"];

/**
 * Maps each inference type to the image slot ("reference" or "control")
 * associated with output dimensions for that type.
 * NOTE(review): "dimensions" semantics are inferred from the constant's name
 * — confirm against callers.
 *
 * The value type also admits "ipAdapter", but no entry currently uses it.
 *
 * The explicit `Record` annotation fully determines this object's type. A
 * trailing `as const` would be silently overridden by it, so none is used:
 * it would only suggest a readonly/literal type that the map does not have.
 */
export const IMAGE_FOR_DIMENSIONS_BY_INFERENCE_TYPE: Record<
  InferenceType,
  "reference" | "control" | "ipAdapter"
> = {
  txt2img: "reference",
  img2img: "reference",
  controlnet: "reference",
  reference: "reference",
  inpaint: "reference",
  outpaint: "reference",
  // ControlNet variants that also carry a separate reference/mask image use
  // the control image instead.
  controlnet_img2img: "control",
  controlnet_reference: "control",
  controlnet_inpaint: "control",
  txt2img_ip_adapter: "reference",
  img2img_ip_adapter: "reference",
  controlnet_ip_adapter: "reference",
  controlnet_inpaint_ip_adapter: "reference",
  inpaint_ip_adapter: "reference",
  img2img_texture: "reference",
  reference_texture: "reference",
  txt2img_texture: "reference",
  controlnet_texture: "reference",
};

/**
 * Fallback influence weights shared by most inference types.
 *
 * - `reference`: 0.4 when `isLive` is set and the model is Flux, else 0.25.
 * - `control`: 1 for SDXL models, else 0.5.
 * - `ipAdapter`: always 0.25.
 *
 * @param model  Model the inference runs on; may be undefined while loading.
 * @param isLive Whether the request comes from a live session.
 */
const defaultGetReferenceInfluence = (
  model: GetModelsByModelIdApiResponse["model"] | undefined,
  isLive: boolean
) => {
  // `isLive` is checked first so the Flux lookup is skipped when not live.
  const liveFlux = isLive && getIsModelFlux(model);
  const sdxl = getIsModelSdxl(model);

  return {
    reference: liveFlux ? 0.4 : 0.25,
    control: sdxl ? 1 : 0.5,
    ipAdapter: 0.25,
  };
};

/** Influence weights for the reference, control and IP-adapter images. */
type ReferenceInfluenceWeights = {
  reference: number;
  control: number;
  ipAdapter: number;
};

/**
 * Builds a getter that ignores model/live state and always returns a fresh
 * object with the given reference and IP-adapter weights and `control: 0`.
 */
const fixedInfluence =
  (reference: number, ipAdapter: number) => (): ReferenceInfluenceWeights => ({
    reference,
    control: 0,
    ipAdapter,
  });

/**
 * Per-inference-type getters for the default influence weights.
 *
 * Each getter receives the (possibly undefined) model and the live flag.
 * Types without special handling fall back to `defaultGetReferenceInfluence`.
 */
export const GET_REFERENCE_INFLUENCE_BY_INFERENCE_TYPE: Record<
  InferenceType,
  (
    model: GetModelsByModelIdApiResponse["model"] | undefined,
    isLive: boolean
  ) => ReferenceInfluenceWeights
> = {
  // Live Flux sessions get a stronger reference weight.
  img2img: (model, isLive) => ({
    reference: (isLive && getIsModelFlux(model)) ? 0.4 : 0.25,
    control: 0,
    ipAdapter: 0,
  }),
  reference: fixedInfluence(0.25, 0),
  // The single ControlNet image is weighted through the `reference` field.
  controlnet: (model) => ({
    reference: getIsModelSdxl(model) ? 1 : 0.5,
    control: 0,
    ipAdapter: 0,
  }),
  txt2img: defaultGetReferenceInfluence,
  inpaint: defaultGetReferenceInfluence,
  outpaint: defaultGetReferenceInfluence,
  controlnet_img2img: defaultGetReferenceInfluence,
  controlnet_reference: defaultGetReferenceInfluence,
  controlnet_inpaint: defaultGetReferenceInfluence,
  txt2img_ip_adapter: fixedInfluence(0.25, 0.25),
  img2img_ip_adapter: fixedInfluence(0.5, 0.25),
  controlnet_ip_adapter: fixedInfluence(0.5, 0.25),
  controlnet_inpaint_ip_adapter: defaultGetReferenceInfluence,
  inpaint_ip_adapter: defaultGetReferenceInfluence,
  txt2img_texture: defaultGetReferenceInfluence,
  reference_texture: defaultGetReferenceInfluence,
  img2img_texture: defaultGetReferenceInfluence,
  // Same as the default, but with a fixed 0.5 reference weight.
  controlnet_texture: (model, isLive) => ({
    ...defaultGetReferenceInfluence(model, isLive),
    reference: 0.5,
  }),
};

/**
 * Default, minimum and maximum guidance values for a model/scheduler pair.
 *
 * - Flux Pro / Pro 1.1 / Pro 1.1 Ultra: default 3, range 2–5.
 * - Other Flux models: default 3.5, range 0–10.
 * - SDXL: default 6 (2.5 with an LCM scheduler), range 0–20.
 * - Anything else, including an undefined model: default 7 (0.5 with an LCM
 *   scheduler), range 0–20.
 */
export function getGuidanceForModel({
  model,
  scheduler,
}: {
  model: GetModelsByModelIdApiResponse["model"] | undefined;
  scheduler: SCHEDULER_TYPE | undefined;
}): {
  default: number;
  min: number;
  max: number;
} {
  const usesLcm = isLCMScheduler(scheduler);

  if (!getIsModelFlux(model)) {
    // Non-Flux models share the 0–20 range; only the default varies.
    const sdxlDefault = usesLcm ? 2.5 : 6;
    const otherDefault = usesLcm ? 0.5 : 7;
    return {
      default: getIsModelSdxl(model) ? sdxlDefault : otherDefault,
      min: 0,
      max: 20,
    };
  }

  const isFluxProFamily =
    getIsModelFluxPro(model) ||
    getIsModelFluxPro1_1(model) ||
    getIsModelFluxPro1_1Ultra(model);

  return isFluxProFamily
    ? { default: 3, min: 2, max: 5 }
    : { default: 3.5, min: 0, max: 10 };
}

/**
 * Default, minimum and maximum sampling-step counts for a model/scheduler pair.
 *
 * Checks run in this order; the first match wins:
 * 1. Flux Pro / Pro 1.1 / Pro 1.1 Ultra: 25 steps, range 1–50.
 * 2. Flux Schnell: 4 steps, range 1–10.
 * 3. Any other Flux model: 28 steps, range 1–50.
 * 4. Any other model with an LCM scheduler: 10 steps, range 4–15.
 * 5. Everything else: 30 steps, range 10–150.
 *
 * The optional `isLive` flag is accepted but does not currently affect the
 * result.
 */
export function getSamplingStepsForModel({
  model,
  scheduler,
}: {
  model: GetModelsByModelIdApiResponse["model"] | undefined;
  scheduler: SCHEDULER_TYPE | undefined;
  isLive?: boolean;
}): {
  default: number;
  min: number;
  max: number;
} {
  const usesLcm = isLCMScheduler(scheduler);

  const isFluxProFamily =
    getIsModelFluxPro(model) ||
    getIsModelFluxPro1_1(model) ||
    getIsModelFluxPro1_1Ultra(model);

  if (isFluxProFamily) {
    return { default: 25, min: 1, max: 50 };
  }
  if (getIsModelFluxSchnell(model)) {
    return { default: 4, min: 1, max: 10 };
  }
  if (getIsModelFlux(model)) {
    return { default: 28, min: 1, max: 50 };
  }
  if (usesLcm) {
    return { default: 10, min: 4, max: 15 };
  }
  return { default: 30, min: 10, max: 150 };
}

/**
 * Maximum number of images that may be requested in one inference.
 *
 * - Non-Flux models, including an undefined model: 16.
 * - Flux Pro / Pro 1.1 / Pro 1.1 Ultra: 1.
 * - Flux Schnell: 4.
 * - Any other Flux model: 8.
 */
export function getMaxNbImagesForModel({
  model,
}: {
  model: GetModelsByModelIdApiResponse["model"] | undefined;
}): number {
  if (!getIsModelFlux(model)) {
    return 16;
  }

  const isFluxProFamily =
    getIsModelFluxPro(model) ||
    getIsModelFluxPro1_1(model) ||
    getIsModelFluxPro1_1Ultra(model);

  if (isFluxProFamily) {
    return 1;
  }
  return getIsModelFluxSchnell(model) ? 4 : 8;
}
