import { getLearningRateValue } from "domains/models/constants/train";
import { TrainForm } from "domains/models/interfaces/train";

/**
 * Computes a default training-step count for a form from the dataset size.
 *
 * - `flux.1-lora`: 500 fixed steps plus 50 per dataset item.
 * - `sd-1_5` with the `unguided` flow: 100 per dataset item.
 * - anything else: 350 per dataset item.
 *
 * @param form - Training form; only `type` and `flow` are read.
 * @param datasetSize - Dataset size used as the per-item multiplier.
 * @returns The computed step count.
 */
export function getDefaultTrainingSteps(form: TrainForm, datasetSize: number) {
  const isFlux = form?.type === "flux.1-lora";
  const isUnguidedSd15 = form?.type === "sd-1_5" && form?.flow === "unguided";

  let stepsPerItem = 350;
  if (isUnguidedSd15) {
    stepsPerItem = 100;
  } else if (isFlux) {
    stepsPerItem = 50;
  }

  // Only Flux gets a fixed offset; others start from 0.
  const fixedSteps = isFlux ? 500 : 0;
  return fixedSteps + datasetSize * stepsPerItem;
}

/**
 * Builds the training payload for a form.
 *
 * The payload has a `dummyModel` identifier chosen from a fixed table keyed
 * by `form.type`. It also has a `body` carrying the name, the type, the class
 * slug and the type-specific training parameters.
 *
 * @param form - Training form values.
 * @param datasetSize - Dataset size, forwarded to `getDefaultTrainingSteps`
 *   when a default step count has to be computed.
 * @param options.isBeforeTrain - When true, placeholder values are resolved:
 *   a `"default"` class slug becomes `"art-style-illustration"`, and a
 *   `totalTrainingSteps` of 0 is replaced with the computed default.
 * @returns An object of shape `{ dummyModel, body }`.
 */
export function getTrainingParameters(
  form: TrainForm,
  datasetSize: number,
  { isBeforeTrain }: { isBeforeTrain?: boolean } = {}
) {
  const dummyModelByType = {
    "sd-xl-lora": "dummy_sd-xl-lora",
    "sd-1_5": "dummy_sd-1_5",
    "flux.1-lora": "dummy_flux.1-lora",
  };

  // Types that receive the UNet / text-encoder learning-rate parameters.
  const sdTypes = ["sd-xl-lora", "sd-1_5"];

  // Only guided SD 1.5 forms produce a non-null class slug.
  const resolveClassSlug = () => {
    if (form.type !== "sd-1_5" || form.flow !== "guided") {
      return null;
    }
    if (form.classSlug === "default") {
      // "default" is expanded only before training; otherwise it becomes null.
      return isBeforeTrain ? "art-style-illustration" : null;
    }
    return form.classSlug ? form.classSlug : null;
  };

  // Unguided SD 1.5 yields `undefined`; every other unmatched case yields
  // `null`. The two values are intentionally not merged.
  const resolveConceptPrompt = () => {
    switch (form.type) {
      case "sd-xl-lora":
        return form.preset === "style" ? "daiton style" : "daiton";
      case "flux.1-lora":
        return "";
      case "sd-1_5":
        return form.flow === "unguided" ? undefined : null;
      default:
        return null;
    }
  };

  return {
    dummyModel: dummyModelByType[form.type],
    body: {
      name: form.name,
      type: form.type,
      classSlug: resolveClassSlug(),
      parameters: {
        conceptPrompt: resolveConceptPrompt(),
        // A step count of 0 is replaced by the computed default only before training.
        maxTrainSteps:
          isBeforeTrain && form.totalTrainingSteps === 0
            ? getDefaultTrainingSteps(form, datasetSize)
            : form.totalTrainingSteps,

        ...(sdTypes.includes(form.type)
          ? {
              learningRateUnet: getLearningRateValue(form.learningRateUnet),
              // A falsy ratio (e.g. 0) is sent as undefined.
              textEncoderTrainingRatio:
                form.textEncoderTrainingRatio || undefined,
              learningRateTextEncoder: getLearningRateValue(
                form.textEncoderLearningRate
              ),
            }
          : {}),

        ...(form.type === "flux.1-lora"
          ? {
              learningRate: getLearningRateValue(form.learningRate),
              optimizeFor: form.isOptimizedForLikeness
                ? ("likeness" as const)
                : null,
            }
          : {}),
      },
    },
  };
}
