import { TrainForm } from "domains/models/interfaces/train";

/** Form type identifiers, in declaration order. `FormType` is derived from this tuple. */
export const FORM_TYPE = [
  "sd-xl-lora",
  "flux.1-lora",
  "sd-1_5",
] as const;

/** Union of the literal entries of `FORM_TYPE`. */
export type FormType = typeof FORM_TYPE[number];

/** One selectable learning rate: its display label and numeric value. */
type LearningRateOption = { label: string; value: number };

type LearningRates = {
  /** Every known learning rate, keyed by its label string. */
  hash: Record<string, LearningRateOption>;
  /** Ascending subsets of `hash` keys, grouped by parameter and model family. */
  lists: {
    global: Record<"flux.1", string[]>;
    unet: Record<"sd-1_5" | "sd-xl", string[]>;
    textEncoder: Record<"sd-1_5" | "sd-xl", string[]>;
  };
};

/**
 * Builds ascending learning-rate labels on a 1..9 × 10^e grid.
 *
 * For every exponent from `fromExponent` up to (but excluding) `toExponent`
 * it emits `1e<exp>` … `9e<exp>`, then closes the range with `1e<toExponent>`.
 * Example: `(-6, -5)` → `["1e-6", "2e-6", …, "9e-6", "1e-5"]`.
 */
function learningRateLabelRange(fromExponent: number, toExponent: number): string[] {
  const labels: string[] = [];
  for (let exponent = fromExponent; exponent < toExponent; exponent += 1) {
    for (let mantissa = 1; mantissa <= 9; mantissa += 1) {
      labels.push(`${mantissa}e${exponent}`);
    }
  }
  labels.push(`1e${toExponent}`);
  return labels;
}

/** Builds the label → option table covering 1e-3 down to 1e-7, largest rate first. */
function buildLearningRateHash(): Record<string, LearningRateOption> {
  const table: Record<string, LearningRateOption> = {};
  // Inserting in descending order yields key order "1e-3", "9e-4", …, "1e-7".
  for (const label of learningRateLabelRange(-7, -3).reverse()) {
    // Number() parses the label to the same double as the equivalent numeric literal.
    table[label] = { label, value: Number(label) };
  }
  return table;
}

export const LEARNING_RATES: LearningRates = {
  hash: buildLearningRateHash(),
  lists: {
    global: {
      "flux.1": learningRateLabelRange(-5, -3),
    },
    unet: {
      "sd-1_5": learningRateLabelRange(-6, -5),
      "sd-xl": learningRateLabelRange(-7, -3),
    },
    textEncoder: {
      "sd-1_5": learningRateLabelRange(-6, -5),
      "sd-xl": learningRateLabelRange(-7, -4),
    },
  },
};

/**
 * Resolves a learning-rate id to its numeric value.
 *
 * @param id - Key into `LEARNING_RATES.hash` (e.g. `"4e-4"`).
 * @returns The numeric rate, or `undefined` when `id` has no entry.
 */
export function getLearningRateValue(id: string): number | undefined {
  const option = LEARNING_RATES.hash[id];
  return option == null ? undefined : option.value;
}

/**
 * Looks up the display label for a learning-rate id.
 *
 * @param id - Key into `LEARNING_RATES.hash` (e.g. `"5e-5"`).
 * @returns The entry's label, or `undefined` when `id` has no entry.
 */
export function getLearningRateLabel(id: string): string | undefined {
  const option = LEARNING_RATES.hash[id];
  if (option == null) {
    return undefined;
  }
  return option.label;
}

/**
 * Reverse lookup: finds the `LEARNING_RATES.hash` key whose numeric value
 * matches `value`.
 *
 * Matching uses an absolute tolerance, so values that picked up float noise
 * (arithmetic, serialization) still resolve to their grid entry. Adjacent
 * entries in the table are at least 1e-7 apart (e.g. "1e-7" vs "2e-7"), so
 * the tolerance is small enough that at most one key can match.
 *
 * @param value - Numeric learning rate to look up.
 * @returns The matching key (e.g. `"4e-4"`), or `undefined` when no entry is
 *   within tolerance (including for `NaN`).
 */
export function findLearningRateId(value: number): string | undefined {
  // Absolute tolerance; far below the 1e-7 minimum spacing of the table.
  const tolerance = 9e-10;
  // Iterate entries so each option is read once and never via an unchecked index.
  const match = Object.entries(LEARNING_RATES.hash).find(
    ([, option]) => Math.abs(option.value - value) < tolerance
  );
  return match?.[0];
}

// ------------------------------------

/**
 * SD 1.5 defaults for the learning-rate fields, one entry per `TrainForm` flow.
 *
 * `Required` forces every flow to spell out all three values; the learning-rate
 * strings are keys of `LEARNING_RATES.hash`.
 */
export const SD15_DEFAULT_DATA: Record<
  TrainForm["flow"],
  Required<
    Pick<
      TrainForm,
      "learningRateUnet" | "textEncoderTrainingRatio" | "textEncoderLearningRate"
    >
  >
> = {
  unguided: {
    learningRateUnet: "2e-6",
    textEncoderTrainingRatio: 0.2,
    textEncoderLearningRate: "1e-6",
  },
  guided: {
    learningRateUnet: "5e-6",
    textEncoderTrainingRatio: 0.15,
    textEncoderLearningRate: "5e-6",
  },
};

// ------------------------------------

/** SDXL preset identifiers, in declaration order; `SdxlPreset` is derived from this tuple. */
export const SDXL_PRESETS = [
  "style",
  "subject",
  "custom",
] as const;

/** Union of the literal entries of `SDXL_PRESETS`. */
export type SdxlPreset = typeof SDXL_PRESETS[number];

/**
 * SDXL defaults, one entry per `SdxlPreset`.
 *
 * Each preset must define the three learning-rate fields. `totalTrainingSteps`
 * is optional: "subject" and "style" reset it to 0, while "custom" leaves it out.
 */
export const SDXL_DEFAULT_DATA: Record<
  SdxlPreset,
  Required<
    Pick<
      TrainForm,
      "learningRateUnet" | "textEncoderTrainingRatio" | "textEncoderLearningRate"
    >
  > &
    Partial<Pick<TrainForm, "totalTrainingSteps">>
> = {
  subject: {
    totalTrainingSteps: 0,
    learningRateUnet: "5e-5",
    textEncoderTrainingRatio: 0.35,
    textEncoderLearningRate: "1e-6",
  },
  style: {
    totalTrainingSteps: 0,
    learningRateUnet: "5e-5",
    textEncoderTrainingRatio: 0.25,
    textEncoderLearningRate: "1e-6",
  },
  custom: {
    learningRateUnet: "5e-5",
    textEncoderTrainingRatio: 0.25,
    textEncoderLearningRate: "1e-6",
  },
};

// ------------------------------------

/**
 * Flux.1 LoRA defaults.
 *
 * `learningRate` is a key of `LEARNING_RATES.hash` ("4e-4" also appears in the
 * `"flux.1"` global list); likeness optimization starts disabled.
 */
export const FLUX_DEFAULT_DATA: Required<
  Pick<TrainForm, "isOptimizedForLikeness" | "learningRate">
> = {
  learningRate: "4e-4",
  isOptimizedForLikeness: false,
};

// ------------------------------------

/**
 * Starting state of the train form.
 *
 * The form opens on the "flux.1-lora" type, so both Flux fields are seeded
 * from `FLUX_DEFAULT_DATA`. The SD-specific fields take the first SD 1.5
 * entries of the UNet / text-encoder learning-rate lists and a 0 ratio.
 * `preset` starts at the first SDXL preset ("style").
 */
export const INITIAL_FORM: TrainForm = {
  name: "New Model",
  type: "flux.1-lora",
  preset: SDXL_PRESETS[0],
  flow: "unguided",
  classSlug: "default",
  totalTrainingSteps: 0,
  learningRate: FLUX_DEFAULT_DATA.learningRate,
  // NOTE(review): learningRateUnet ("1e-6") and textEncoderTrainingRatio (0)
  // differ from SD15_DEFAULT_DATA.unguided; presumably getDefaultData() is
  // applied when the user switches type/flow — confirm against the form logic.
  learningRateUnet: LEARNING_RATES.lists.unet["sd-1_5"][0],
  textEncoderTrainingRatio: 0,
  textEncoderLearningRate: LEARNING_RATES.lists.textEncoder["sd-1_5"][0],
  // Read from the Flux defaults (like learningRate) so the two cannot drift apart.
  isOptimizedForLikeness: FLUX_DEFAULT_DATA.isOptimizedForLikeness,
  tags: [],
};

// ------------------------------------

/**
 * Returns the default form values for a model type.
 *
 * - `"sd-1_5"`: the `SD15_DEFAULT_DATA` entry for `flow`.
 * - `"sd-xl-lora"`: the `SDXL_DEFAULT_DATA` entry for `preset`.
 * - any other type (`"flux.1-lora"`): `FLUX_DEFAULT_DATA`.
 *
 * The result is a fresh shallow copy on every call. The default tables are
 * exported, mutable objects, and handing out references to them would let one
 * caller's edits leak into every later call.
 *
 * @param args - The model type plus the sub-selector that type requires.
 * @returns A partial `TrainForm` holding only the fields that type defaults.
 */
export function getDefaultData(
  args:
    | {
        type: "sd-1_5";
        flow: TrainForm["flow"];
      }
    | {
        type: "sd-xl-lora";
        preset: SdxlPreset;
      }
    | {
        type: "flux.1-lora";
      }
): Partial<TrainForm> {
  if (args.type === "sd-1_5") {
    return { ...SD15_DEFAULT_DATA[args.flow] };
  }
  if (args.type === "sd-xl-lora") {
    return { ...SDXL_DEFAULT_DATA[args.preset] };
  }
  // Flux; also the fallback for any other runtime value of `type`.
  return { ...FLUX_DEFAULT_DATA };
}
