import React, {
  Dispatch,
  SetStateAction,
  useCallback,
  useEffect,
  useMemo,
  useState,
} from "react";
import { useDebounce } from "domains/commons/hooks/useDebounce";
import ModalReviewCaption from "domains/models/components/ModelTrain/ModalReviewCaption";
import SelectClass from "domains/models/components/ModelTrain/SelectClass";
import { ModelTag } from "domains/models/constants/tags";
import {
  FormType,
  getDefaultData,
  getLearningRateLabel,
  LEARNING_RATES,
  SDXL_DEFAULT_DATA,
  SdxlPreset,
} from "domains/models/constants/train";
import { TrainForm } from "domains/models/interfaces/train";
import { getTrainingParameters } from "domains/models/utils/train";
import TagDisplayWithRemove from "domains/tags/components/TagDisplayWithRemove";
import TagInput from "domains/tags/components/TagInput";
import TagSuggestion from "domains/tags/components/TagSuggestion";
import { useTeamContext } from "domains/teams/contexts/TeamProvider";
import Button from "domains/ui/components/Button";
import ButtonWithCuIndicator from "domains/ui/components/ButtonWithCuIndicator";
import { ButtonSelect } from "domains/ui/components/ControlButtonSelect";
import Sidebar, { SidebarSection } from "domains/ui/components/Sidebar";
import SidebarTitle from "domains/ui/components/Sidebar/SidebarTitle";
import Slider from "domains/ui/components/Slider";
import SliderStepsPaul from "domains/ui/components/Slider/SliderStepsPaul";
import WithLabelAndTooltip from "domains/ui/components/WithLabelAndTooltip";
import { useUserContext } from "domains/user/contexts/UserProvider";
import { GetModelsByModelIdApiResponse } from "infra/api/generated/api";
import { useGetModelTrainCostQuery } from "infra/store/apiSlice";
import _ from "lodash";

import {
  Input,
  SimpleGrid,
  Skeleton,
  Switch,
  Tooltip,
  VStack,
} from "@chakra-ui/react";
import { useAutoAnimate } from "@formkit/auto-animate/react";
import { skipToken } from "@reduxjs/toolkit/query";

/** Props for the `ModelTrainSidebar` component. */
interface ModelTrainSidebarProps {
  // ---- Form data ----------------------------------------------------------

  /** Model being configured; the length of its `trainingImages` is used as the dataset size. */
  model: GetModelsByModelIdApiResponse["model"] | undefined;
  /** Current training form; while `undefined` the sidebar renders skeleton placeholders. */
  form: TrainForm | undefined;
  /** Setter for `form`; every edit made in the sidebar is written through it. */
  setForm: Dispatch<SetStateAction<TrainForm | undefined>>;
  /** When `false`, the first "Start Training" click opens the caption review modal instead of training. */
  haveCaptionsBeenUpdated: boolean;

  // ---- Training steps -----------------------------------------------------

  /** Step count restored when automatic steps are turned off; also the slider's fallback value. */
  defaultTrainingSteps: number;
  /** Lower bound of the "Total Training Steps" slider (falls back to 1 when 0). */
  minTrainingSteps: number;
  /** Upper bound of the "Total Training Steps" slider; 0 disables the slider. */
  maxTrainingSteps: number;

  // ---- Action state -------------------------------------------------------

  /** Enables the "Save As Draft" button. */
  isSavable?: boolean;
  /** Shows a loading state on the "Save As Draft" button. */
  isSaving?: boolean;
  /** Enables "Start Training" and the training cost lookup. */
  isTrainable?: boolean;
  /** Shows a loading state on the "Start Training" button. */
  isTraining?: boolean;
  /** Reason training is unavailable; selects the tooltip shown on the disabled train button. */
  trainingError?:
    | "notEnoughTrainingImages"
    | "tooMuchTrainingImages"
    | "missingClassSlug"
    | "lowResTrainingImages"
    | "imageSizesLoading"
    | "guessingDescription";

  // ---- Callbacks ----------------------------------------------------------

  /** Invoked by the "Save As Draft" button. */
  onSaveClick: () => void;
  /** Invoked to start training, either directly or from the caption review modal. */
  onTrainClick: () => void;

  // ---- Presentation -------------------------------------------------------

  /** Hides the sidebar title (and its guide/video links) when `true`. */
  isTitleHidden?: boolean;
}

/**
 * Sidebar used to configure and launch a model training run.
 *
 * Renders the model name, base model / preset / flow selectors, the
 * "Start Training" and "Save As Draft" actions, advanced hyper-parameters and
 * tags. While `form` is undefined, skeleton placeholders are shown instead of
 * the interactive controls.
 *
 * The parent owns the form state: every edit is written back through
 * `setForm`, and the parent decides (via `isSavable` / `isTrainable` /
 * `trainingError`) whether the model can be saved or trained.
 */
export default function ModelTrainSidebar({
  model,
  isTitleHidden,
  form,
  setForm,
  defaultTrainingSteps,
  minTrainingSteps,
  maxTrainingSteps,
  isSavable,
  isSaving,
  isTrainable,
  isTraining,
  trainingError,
  onSaveClick,
  onTrainClick,
  haveCaptionsBeenUpdated,
}: ModelTrainSidebarProps) {
  const { featureFlags } = useUserContext();
  const { selectedProject } = useTeamContext();
  const [animationRef] = useAutoAnimate();
  const [isCaptionModalOpen, setIsCaptionModalOpen] = useState<boolean>(false);
  const [hasAlreadyShownCaptionModal, setHasAlreadyShownCaptionModal] =
    useState<boolean>(false);

  const isSd15Available = featureFlags.includes("sd15-training");
  const isSkeleton = !form;
  // One grid column per base model button: Flux, SDXL, Bria (+ SD 1.5 when enabled).
  const baseModelColumns = isSd15Available ? 4 : 3;

  const datasetSize = model?.trainingImages?.length ?? 0;
  const name = form?.name;
  const type = form?.type;
  const preset = form?.preset;
  const flow = form?.flow;
  const totalTrainingSteps = form?.totalTrainingSteps;
  const learningRate = form?.learningRate;
  const learningRateUnet = form?.learningRateUnet;
  const textEncoderTrainingRatio = form?.textEncoderTrainingRatio || 0;
  const textEncoderLearningRate = form?.textEncoderLearningRate;
  const isOptimizedForLikeness = form?.isOptimizedForLikeness;

  // Learning-rate slider steps for the selected base model. Flux uses a single
  // global list; SDXL and everything else (including `undefined` while loading)
  // use separate UNet / text-encoder lists, falling back to the SD 1.5 lists.
  const learningRateSteps = useMemo(
    () =>
      (type?.startsWith("flux.1") && {
        global: LEARNING_RATES.lists.global["flux.1"],
      }) ||
      (type?.startsWith("sd-xl") && {
        unet: LEARNING_RATES.lists.unet["sd-xl"],
        textEncoder: LEARNING_RATES.lists.textEncoder["sd-xl"],
      }) || {
        unet: LEARNING_RATES.lists.unet["sd-1_5"],
        textEncoder: LEARNING_RATES.lists.textEncoder["sd-1_5"],
      },
    [type]
  );

  // Training parameters are only computed once the model is trainable; they
  // feed the cost query below.
  const parameters = useMemo(() => {
    return model && form && isTrainable
      ? getTrainingParameters(form, datasetSize, { isBeforeTrain: true })
      : undefined;
  }, [model, form, datasetSize, isTrainable]);
  // Debounced (delay 500, presumably ms) so that slider drags do not fire a
  // cost request on every change.
  const debouncedParameters = useDebounce(parameters, 500);

  const { data: cuCostData, isFetching: isCuLoading } =
    useGetModelTrainCostQuery(
      !isTrainable || !model || !debouncedParameters
        ? skipToken
        : {
            projectId: selectedProject.id,
            modelId: debouncedParameters.dummyModel,
            trainingImagesCount: datasetSize,
            body: {
              parameters: debouncedParameters.body.parameters,
            },
          }
    );

  // ----------------------------------

  /**
   * Switches the base model and merges that model's default parameters into
   * the form. SD 1.5 keeps the current flow; SDXL keeps the current preset.
   */
  const handleFormChangeType = useCallback(
    (value: FormType) => {
      if (!form) return;

      const defaultData = (() => {
        if (value === "sd-1_5") {
          return getDefaultData({
            type: "sd-1_5",
            flow: form.flow,
          });
        } else if (value === "sd-xl-lora") {
          return getDefaultData({
            type: "sd-xl-lora",
            preset: form.preset,
          });
        } else {
          return getDefaultData({
            type: "flux.1-lora",
          });
        }
      })();

      setForm({
        ...form,
        type: value,
        ...defaultData,
      });
    },
    [setForm, form]
  );

  /**
   * Changes the SDXL preset. Non-custom presets overwrite the form with their
   * default parameters; "custom" keeps the current values. No-op for other
   * base models.
   */
  const handleFormChangePreset = useCallback(
    (value: SdxlPreset) => {
      if (form?.type !== "sd-xl-lora") return;
      const defaultData =
        value !== "custom"
          ? getDefaultData({
              type: "sd-xl-lora",
              preset: value,
            })
          : {};
      setForm({
        ...form,
        preset: value,
        ...defaultData,
      });
    },
    [setForm, form]
  );

  /**
   * Changes the SD 1.5 training flow and applies that flow's default
   * parameters. No-op for other base models.
   */
  const handleFormChangeFlow = useCallback(
    (value: "guided" | "unguided") => {
      if (form?.type !== "sd-1_5") return;
      const defaultData = getDefaultData({
        type: "sd-1_5",
        flow: value,
      });
      setForm({
        ...form,
        flow: value,
        ...defaultData,
      });
    },
    [setForm, form]
  );

  /**
   * If captions were never updated, the first click opens the caption review
   * modal (once per mount); any later click, or an already-reviewed set of
   * captions, starts training directly.
   */
  const handleTrainClick = useCallback(() => {
    if (!haveCaptionsBeenUpdated && !hasAlreadyShownCaptionModal) {
      setIsCaptionModalOpen(true);
      setHasAlreadyShownCaptionModal(true);
    } else {
      onTrainClick();
    }
  }, [hasAlreadyShownCaptionModal, haveCaptionsBeenUpdated, onTrainClick]);

  // ----------------------------------

  /** Sets a single form field, using a functional update so edits don't clobber each other. */
  const changeFormItem = useCallback(
    <T extends keyof TrainForm>(key: T, value: TrainForm[T]) => {
      setForm((form) =>
        form
          ? {
              ...form,
              [key]: value,
            }
          : form
      );
    },
    [setForm]
  );

  /**
   * Applies an SDXL preset and keeps the "sc:style" / "sc:subject" tags in
   * sync with it: the tag of the selected preset is added (none for
   * "custom") and the other preset's tag is removed.
   */
  const updatePreset = useCallback(
    (presetType: "custom" | "subject" | "style") => {
      handleFormChangePreset(presetType);

      const currentTags = form?.tags ?? [];
      const tagsMap: { [key: string]: ModelTag } = {
        style: "sc:style",
        subject: "sc:subject",
      };

      const tagsToRemove = Object.keys(tagsMap)
        .filter((tag) => tag !== presetType)
        .map((tag) => tagsMap[tag]);

      // `handleFormChangePreset` replaced the form with a plain object; this
      // functional update is applied on top of it, so both changes persist.
      changeFormItem(
        "tags",
        _.uniq([
          ...currentTags.filter((tag) => !tagsToRemove.includes(tag)),
          ...(presetType === "custom" ? [] : [tagsMap[presetType]]),
        ])
      );
    },
    [changeFormItem, form?.tags, handleFormChangePreset]
  );

  // ----------------------------------

  // When the user edits any parameter that a non-custom SDXL preset controls,
  // the form no longer matches the preset's defaults: switch it to "custom".
  useEffect(() => {
    if (!form || form.type !== "sd-xl-lora" || form.preset === "custom") {
      return;
    }

    const defaultData = SDXL_DEFAULT_DATA[form.preset];
    if (!_.isEqual(defaultData, _.pick(form, _.keys(defaultData)))) {
      handleFormChangePreset("custom");
    }
  }, [handleFormChangePreset, form]);

  // ----------------------------------

  return (
    <>
      <Sidebar
        id="modelTrain"
        isMobile={false}
        title={
          isTitleHidden ? undefined : (
            <SidebarTitle
              title="Train a Model"
              guide="training-intro"
              videoUrl="https://www.youtube.com/embed/T8G73YoblQA?si=8xbCE9Q9T1lw8UAW"
              videoTooltip="Watch a quick video to learn about custom model training"
            />
          )
        }
      >
        <SidebarSection
          id="top"
          collapseProps={{
            style: {
              overflow: "visible",
            },
          }}
        >
          <VStack ref={animationRef} spacing={4}>
            <Skeleton w="full" borderRadius="md" isLoaded={!isSkeleton}>
              <Input
                bgColor="background.500"
                onChange={(e) => changeFormItem("name", e.target.value)}
                placeholder="Model Name"
                type="text"
                value={name}
              />
            </Skeleton>

            <WithLabelAndTooltip direction="column" label="Base Model">
              <SimpleGrid columns={baseModelColumns} spacing={2}>
                {isSkeleton ? (
                  // One placeholder per column so the loading layout matches
                  // the loaded one.
                  Array.from({ length: baseModelColumns }, (_item, index) => (
                    <Skeleton key={index} h="36px" borderRadius="md" />
                  ))
                ) : (
                  <>
                    <ButtonSelect
                      isSelected={type === "flux.1-lora"}
                      onClick={() => handleFormChangeType("flux.1-lora")}
                      tooltip="Flux training includes commercial rights."
                    >
                      Flux
                    </ButtonSelect>

                    <ButtonSelect
                      isSelected={type === "sd-xl-lora"}
                      onClick={() => handleFormChangeType("sd-xl-lora")}
                    >
                      SDXL
                    </ButtonSelect>

                    {/* We want to test the market regarding Bria training */}
                    <ButtonSelect
                      isDisabled
                      isSelected={false}
                      onClick={() => {}}
                      tooltip="Training on Bria is not enabled for your account. Please contact us for more details."
                    >
                      Bria
                    </ButtonSelect>

                    {isSd15Available && (
                      <ButtonSelect
                        isSelected={type === "sd-1_5"}
                        onClick={() => handleFormChangeType("sd-1_5")}
                      >
                        SD 1.5
                      </ButtonSelect>
                    )}
                  </>
                )}
              </SimpleGrid>
            </WithLabelAndTooltip>

            {/* `type` is only defined once `form` is loaded, so the preset and
                flow sections never need skeleton placeholders. */}
            {type === "sd-xl-lora" && (
              <WithLabelAndTooltip
                direction="column"
                label="Preset"
                guide={preset ? `training-type-${preset}` : undefined}
              >
                <SimpleGrid columns={1} spacing={2}>
                  <ButtonSelect
                    isSelected={preset === "style"}
                    onClick={() => updatePreset("style")}
                  >
                    Style
                  </ButtonSelect>

                  <Tooltip
                    label="Also suitable for consistent object training"
                    placement="right"
                  >
                    <VStack align="stretch" w="100%">
                      <ButtonSelect
                        isSelected={preset === "subject"}
                        onClick={() => updatePreset("subject")}
                      >
                        Subject
                      </ButtonSelect>
                    </VStack>
                  </Tooltip>

                  <ButtonSelect
                    isSelected={preset === "custom"}
                    onClick={() => updatePreset("custom")}
                  >
                    Custom (Advanced)
                  </ButtonSelect>
                </SimpleGrid>
              </WithLabelAndTooltip>
            )}

            {type === "sd-1_5" && (
              <WithLabelAndTooltip direction="column" label="Training Flow">
                <SimpleGrid columns={2} spacing={2}>
                  <ButtonSelect
                    isSelected={flow === "unguided"}
                    onClick={() => handleFormChangeFlow("unguided")}
                  >
                    Unguided
                  </ButtonSelect>

                  <ButtonSelect
                    isSelected={flow === "guided"}
                    onClick={() => handleFormChangeFlow("guided")}
                  >
                    Guided
                  </ButtonSelect>
                </SimpleGrid>
              </WithLabelAndTooltip>
            )}

            {type === "sd-1_5" && flow === "guided" && (
              <WithLabelAndTooltip label="Training Class">
                <SelectClass
                  setSelectedClassSlug={(value) =>
                    changeFormItem("classSlug", value)
                  }
                  selectedClassSlug={form?.classSlug}
                />
              </WithLabelAndTooltip>
            )}
          </VStack>
        </SidebarSection>

        <SidebarSection id="action">
          <VStack align="stretch" spacing={2}>
            {/* The tooltip only explains why training is unavailable, so it is
                disabled while loading or when training is allowed. */}
            <Tooltip
              isDisabled={isSkeleton || isTrainable}
              label={
                trainingError &&
                {
                  notEnoughTrainingImages:
                    "Add at least 5 images to start training.",
                  tooMuchTrainingImages:
                    "Image limit exceeded. Remove some images to start training.",
                  lowResTrainingImages: `Low-Resolution Image Detected. Please upscale your image(s) before initiating training.`,
                  imageSizesLoading:
                    "Your images are being analyzed. This may take a few seconds.",
                  guessingDescription:
                    "A caption is required for all training images to initiate the training process.",
                  missingClassSlug:
                    "Select a training class to start training.",
                }[trainingError]
              }
              placement="right"
            >
              <VStack align="stretch" w="100%">
                <ButtonWithCuIndicator
                  variant="primary"
                  size="sm"
                  isDisabled={isSkeleton || !isTrainable}
                  isLoading={isTraining}
                  onClick={handleTrainClick}
                  cuCost={
                    isTrainable && parameters
                      ? cuCostData?.creativeUnitsCost
                      : undefined
                  }
                  // Also "loading" while the debounce has not caught up, so a
                  // stale cost is never presented as current.
                  isCuLoading={
                    isCuLoading || !_.isEqual(parameters, debouncedParameters)
                  }
                >
                  Start Training
                </ButtonWithCuIndicator>
              </VStack>
            </Tooltip>

            <Button
              variant="secondary"
              size="sm"
              isDisabled={isSkeleton || !isSavable}
              isLoading={isSaving}
              onClick={onSaveClick}
            >
              Save As Draft
            </Button>
          </VStack>
        </SidebarSection>

        <SidebarSection id="advanced" title="Advanced Settings">
          {isSkeleton ? (
            <VStack align="stretch" spacing={4}>
              <VStack align="flex-start" py={2} spacing={3}>
                <Skeleton w="140px" h="12px" />
                <Skeleton w="200px" h="12px" />
              </VStack>

              <VStack align="flex-start" py={2} spacing={3}>
                <Skeleton w="130px" h="12px" />
                <Skeleton w="220px" h="12px" />
              </VStack>

              <VStack align="flex-start" py={2} spacing={3}>
                <Skeleton w="150px" h="12px" />
                <Skeleton w="170px" h="12px" />
              </VStack>
            </VStack>
          ) : (
            <VStack align="stretch" spacing={4}>
              <WithLabelAndTooltip
                labelProps={{ w: "full" }}
                label="Automatic Training Steps"
              >
                <Switch
                  isChecked={!totalTrainingSteps}
                  onChange={() => {
                    // Automatic mode is represented by a falsy step count (0 or
                    // unset). Test truthiness, exactly like `isChecked`, so the
                    // first click also turns automatic mode off when the value
                    // is still undefined (a strict `=== 0` check would write 0
                    // and leave the switch checked).
                    changeFormItem(
                      "totalTrainingSteps",
                      totalTrainingSteps ? 0 : defaultTrainingSteps
                    );
                  }}
                />
              </WithLabelAndTooltip>

              <WithLabelAndTooltip
                direction="column"
                label="Total Training Steps"
                guide="training-parameter-steps"
              >
                <Slider
                  withNumberInput
                  valueNumberInputProps={{
                    pos: "absolute",
                    top: 0,
                    right: 0,
                    w: "60px",
                    mt: -0.5,
                  }}
                  max={maxTrainingSteps || 500}
                  min={minTrainingSteps || 1}
                  onChange={(value) => {
                    changeFormItem("totalTrainingSteps", value);
                  }}
                  step={10}
                  isDisabled={maxTrainingSteps === 0}
                  value={totalTrainingSteps || defaultTrainingSteps || 350}
                />
              </WithLabelAndTooltip>

              {type === "flux.1-lora" && (
                <>
                  <WithLabelAndTooltip
                    direction="column"
                    label="Learning Rate"
                    guide="training-parameter-learning-rate"
                    rightLabel={
                      (learningRate && getLearningRateLabel(learningRate)) || ""
                    }
                  >
                    <SliderStepsPaul
                      hideValue
                      steps={learningRateSteps.global ?? []}
                      onChange={(value) => {
                        changeFormItem("learningRate", value);
                      }}
                      value={learningRate}
                    />
                  </WithLabelAndTooltip>

                  <WithLabelAndTooltip
                    labelProps={{ w: "full" }}
                    label="Optimize for Portraits"
                    tooltip="Fast training for human portraits"
                  >
                    <Switch
                      isChecked={isOptimizedForLikeness}
                      onChange={() =>
                        changeFormItem(
                          "isOptimizedForLikeness",
                          !isOptimizedForLikeness
                        )
                      }
                    />
                  </WithLabelAndTooltip>
                </>
              )}

              {_.includes(["sd-xl-lora", "sd-1_5"], type) && (
                <>
                  <WithLabelAndTooltip
                    direction="column"
                    label={`UNet Learning Rate`}
                    guide="training-parameter-learning-rate"
                    rightLabel={
                      (learningRateUnet &&
                        getLearningRateLabel(learningRateUnet)) ||
                      ""
                    }
                  >
                    <SliderStepsPaul
                      hideValue
                      steps={learningRateSteps.unet ?? []}
                      onChange={(value) => {
                        changeFormItem("learningRateUnet", value);
                      }}
                      value={learningRateUnet}
                    />
                  </WithLabelAndTooltip>

                  <WithLabelAndTooltip
                    direction="column"
                    label={`Text Encoder Training Ratio`}
                    guide="training-parameter-text-encoder-ratio"
                    rightLabel={textEncoderTrainingRatio.toLocaleString(
                      "en-US",
                      {
                        minimumFractionDigits: 2,
                        maximumFractionDigits: 2,
                      }
                    )}
                  >
                    <Slider
                      hideValue
                      max={0.99}
                      min={0}
                      onChange={(value) => {
                        changeFormItem("textEncoderTrainingRatio", value);
                      }}
                      step={0.05}
                      value={textEncoderTrainingRatio}
                    />
                  </WithLabelAndTooltip>

                  <WithLabelAndTooltip
                    direction="column"
                    label={`Text Encoder Learning Rate`}
                    rightLabel={
                      (textEncoderLearningRate &&
                        getLearningRateLabel(textEncoderLearningRate)) ||
                      ""
                    }
                  >
                    <SliderStepsPaul
                      hideValue
                      steps={learningRateSteps.textEncoder ?? []}
                      onChange={(value) =>
                        changeFormItem("textEncoderLearningRate", value)
                      }
                      value={textEncoderLearningRate}
                    />
                  </WithLabelAndTooltip>
                </>
              )}
            </VStack>
          )}
        </SidebarSection>

        <SidebarSection id="tags" title="Tags">
          <VStack align="stretch">
            <TagDisplayWithRemove
              tags={form?.tags ?? []}
              onRemove={(tag) =>
                changeFormItem(
                  "tags",
                  (form?.tags ?? []).filter((t) => t !== tag)
                )
              }
            />
            <TagInput
              currentTags={form?.tags ?? []}
              onAddTags={(tags) =>
                changeFormItem("tags", [...(form?.tags ?? []), ...tags])
              }
            />
            <TagSuggestion
              tags={form?.tags ?? []}
              onUpdateTags={(tags) => changeFormItem("tags", tags)}
            />
          </VStack>
        </SidebarSection>
      </Sidebar>
      <ModalReviewCaption
        isOpen={isCaptionModalOpen}
        onClose={() => setIsCaptionModalOpen(false)}
        onTrain={onTrainClick}
      />
    </>
  );
}
