import { useCallback, useMemo, useRef, useState } from "react";
import { useAssetReframeCu } from "domains/assets/hooks/useAssetReframeCu";
import { InferenceBuilderErrors } from "domains/canvas/hooks/useInferenceBuilderErrors";
import { makeNumberAMultipleOfEight } from "domains/commons/misc";
import { useSectionsContext } from "domains/inference/contexts/SectionsProvider";
import { getIsModelFlux } from "domains/models/utils";
import { useScenarioToast } from "domains/notification/hooks/useScenarioToast";
import {
  AspectRatioValue,
  ratioStringToNumber,
} from "domains/reframe/components/ReframeParams";
import { ReframeForm } from "domains/reframe/interfaces/Reframe";
import { useTeamContext } from "domains/teams/contexts/TeamProvider";
import { AnalyticsEvents } from "infra/analytics/constants/Events";
import Track from "infra/analytics/Track";
import { removeCuFromType } from "infra/api/ExludeCU";
import {
  GetAssetsByAssetIdApiResponse,
  GetJobIdApiResponse,
  GetModelsByModelIdApiResponse,
  PostReframeInferencesApiArg,
  useLazyGetJobIdQuery,
  useLazyGetModelsByModelIdQuery,
  usePostReframeInferencesMutation,
} from "infra/api/generated/api";

import { useDebounceCallback } from "@react-hook/debounce";

/** Square side (in pixels) used for both the default input and target size. */
const DEFAULT_REFRAME_SIDE = 1024;

/**
 * Initial state of the reframe form.
 *
 * No asset is selected and the prompt is empty. Input and target are both
 * `DEFAULT_REFRAME_SIDE` squared. `inputLocation` is `"middle"` and the overlap
 * is 30 (percent, before `createReframeBody` converts it). The input is kept
 * at native scale (`inputScale` 1). `outputScale` is 1.2, the multiplier the
 * hook applies to target dimensions when an asset is picked.
 */
export const DEFAULT_REFRAME_FORM: ReframeForm = {
  asset: undefined,
  prompt: "",
  targetWidth: DEFAULT_REFRAME_SIDE,
  targetHeight: DEFAULT_REFRAME_SIDE,
  numInferenceSteps: 10,
  inputLocation: "middle",
  overlapPercentage: 30,
  inputWidth: DEFAULT_REFRAME_SIDE,
  inputHeight: DEFAULT_REFRAME_SIDE,
  inputScale: 1,
  outputScale: 1.2,
};

/**
 * Everything `useAssetReframe` exposes to its callers.
 */
export interface UseAssetReframeReturnValue {
  /** Current form values. */
  form: ReframeForm;
  /** The raw React state setter for `form`; it applies no derived updates. */
  setForm: React.Dispatch<React.SetStateAction<ReframeForm>>;
  /** Sets a single form field; the hook may derive other fields from it. */
  setValue: <K extends keyof ReframeForm>(
    key: K,
    value: ReframeForm[K]
  ) => void;

  /** Validation problems for the current form. */
  errors: InferenceBuilderErrors;
  /** True when there is no asset or a validation error of type "error". */
  isDisabled: boolean;

  /** CU cost reported for the current request body, when available. */
  cuCost?: number;
  /** Whether the CU cost is still being fetched. */
  isCuLoading: boolean;

  /**
   * Submits the reframe inference for `asset`. Resolves to the created job,
   * or `undefined` when the form is disabled or the request fails.
   */
  handleReframeJob: (props: {
    asset: GetAssetsByAssetIdApiResponse["asset"];
    trackingExtraParams: Record<string, unknown>;
  }) => Promise<GetJobIdApiResponse["job"] | undefined>;
  /** True while a `handleReframeJob` call is in progress. */
  isReframeLoading: boolean;
}

/**
 * State and actions for the reframe ("expand") form.
 *
 * Responsibilities:
 * - Owns the form state.
 * - Derives target and input dimensions from the selected asset, the aspect
 *   ratio and the output scale, capped at 2048px per side.
 * - Collects Flux LoRA concepts from the asset's parent job so they are sent
 *   with the request.
 * - Validates the form and exposes the CU cost from `useAssetReframeCu`.
 * - Submits the reframe inference.
 */
export function useAssetReframe(): UseAssetReframeReturnValue {
  /** A concept forwarded to the reframe request (only Flux models are kept). */
  type ReframeConcept = NonNullable<
    GetModelsByModelIdApiResponse["model"]["concepts"]
  >[0] & {
    weightName: "lora.safetensors";
  };

  const { selectedProject } = useTeamContext();
  const { errorToast, infoToast } = useScenarioToast();
  const [createReframeTrigger] = usePostReframeInferencesMutation();
  const [getJobTrigger] = useLazyGetJobIdQuery();
  const [getModelTrigger] = useLazyGetModelsByModelIdQuery();
  const { toggleCollapsedSection } = useSectionsContext();

  const [form, setForm] = useState<ReframeForm>(DEFAULT_REFRAME_FORM);
  const [isReframeLoading, setIsReframeLoading] = useState(false);
  const debouncedInfoToast = useDebounceCallback(infoToast, 5_000, true);

  // Stored in a ref rather than state: handleReframeJob polls `isLoading`
  // across awaits and must see the latest value, not a render snapshot.
  const concepts = useRef<{
    isLoading: boolean;
    concepts: ReframeConcept[] | undefined;
  }>({
    isLoading: false,
    concepts: undefined,
  });

  // Bumped on every asset change. An async asset load only applies its
  // results if its id is still the latest, so a slow response for an
  // earlier asset cannot overwrite a newer selection.
  const assetRequestId = useRef(0);

  /**
   * Computes target and input dimensions for an asset of size
   * `baseWidth` x `baseHeight`.
   *
   * The result matches the requested aspect ratio ("Default" when unset) and
   * is multiplied by `outputScale`. When either target side exceeds 2048, the
   * target and the input are both scaled down uniformly and `inputScale`
   * records that factor.
   *
   * Targets are not rounded to multiples of 8 here; `setValue` does that.
   */
  const calculateDimensions = useCallback(
    ({
      baseWidth,
      baseHeight,
      aspectRatio,
      outputScale,
    }: {
      baseWidth: number;
      baseHeight: number;
      aspectRatio?: AspectRatioValue;
      outputScale: number;
    }) => {
      const ratio = ratioStringToNumber(aspectRatio ?? "Default", {
        width: baseWidth,
        height: baseHeight,
      });

      let targetWidth: number;
      let targetHeight: number;

      if (ratio === 1) {
        const side = Math.max(baseWidth, baseHeight);
        targetWidth = side;
        targetHeight = side;
      } else if (ratio > 1) {
        targetHeight = baseHeight;
        targetWidth = baseHeight * ratio;
      } else {
        targetWidth = baseWidth;
        targetHeight = baseWidth / ratio;
      }

      targetWidth = Math.round(targetWidth * outputScale);
      targetHeight = Math.round(targetHeight * outputScale);

      const updates: Partial<ReframeForm> = {
        targetWidth,
        targetHeight,
        outputScale,
        inputWidth: baseWidth,
        inputHeight: baseHeight,
        inputScale: 1,
      };

      const MAX_SIZE = 2_048;
      if (targetWidth > MAX_SIZE || targetHeight > MAX_SIZE) {
        const scale = Math.min(MAX_SIZE / targetWidth, MAX_SIZE / targetHeight);

        if (scale < 1) {
          updates.targetWidth = targetWidth * scale;
          updates.targetHeight = targetHeight * scale;
          updates.inputScale = scale;
          updates.inputWidth = baseWidth * scale;
          updates.inputHeight = baseHeight * scale;
        }
      }

      return updates;
    },
    []
  );

  /**
   * Fetches the asset's parent job and returns the concepts from its input
   * whose model is a Flux model.
   *
   * Returns `[]` when the asset has no parent job or the job input holds no
   * concept array. Models are fetched one at a time, in input order.
   */
  const loadFluxConcepts = useCallback(
    async (
      asset: NonNullable<ReframeForm["asset"]>
    ): Promise<ReframeConcept[]> => {
      const parentJobId = asset.metadata.parentJobId;
      if (!parentJobId) {
        return [];
      }

      const job = await getJobTrigger({
        projectId: selectedProject.id,
        jobId: parentJobId,
      });

      // The job input is untyped. Check it is really an array before
      // iterating so malformed metadata can't throw mid-load.
      const jobConcepts: unknown = (job?.data?.job?.metadata.input as any)
        ?.concepts;
      if (!Array.isArray(jobConcepts)) {
        return [];
      }

      const fluxConcepts: ReframeConcept[] = [];
      for (const concept of jobConcepts) {
        const model = await getModelTrigger({
          projectId: selectedProject.id,
          modelId: concept.modelId,
        });
        if (model?.data?.model && getIsModelFlux(model.data.model)) {
          fluxConcepts.push({
            modelId: model.data.model.id,
            scale: concept.scale,
            weightName: "lora.safetensors",
          });
        }
      }
      return fluxConcepts;
    },
    [getJobTrigger, getModelTrigger, selectedProject.id]
  );

  /**
   * Sets one form field and recomputes the fields that depend on it.
   *
   * - `asset`: loads the parent job's Flux concepts, copies the asset prompt
   *   (if any) and recomputes dimensions from the asset size. A width or
   *   height of 512 is used when the size is missing. The output scale is
   *   reset to the default.
   * - `aspectRatio` / `outputScale`: recomputes dimensions from the current
   *   asset.
   * - `inputScale`: rescales the input dimensions from the asset size.
   *
   * Target dimensions are always rounded to multiples of 8.
   */
  const setValue = useCallback(
    async <T extends keyof ReframeForm>(key: T, value: ReframeForm[T]) => {
      let updates: Partial<ReframeForm> = { [key]: value };

      // Handle asset change and concepts
      if (key === "asset") {
        // Every asset change, including clearing it, supersedes any concept
        // load still in flight for a previously selected asset.
        const requestId = ++assetRequestId.current;
        const newAsset = value as ReframeForm["asset"];

        if (!newAsset) {
          // Reset here because a superseded load will not clear `isLoading`
          // itself (see the `finally` below).
          concepts.current = { isLoading: false, concepts: undefined };
        } else {
          concepts.current = { isLoading: true, concepts: undefined };
          let newConcepts: ReframeConcept[] = [];
          try {
            newConcepts = await loadFluxConcepts(newAsset);
          } finally {
            // Clear the loading flag even if a request throws. Otherwise
            // handleReframeJob's wait loop would never end. Skip this when a
            // newer asset owns the concept state.
            if (requestId === assetRequestId.current) {
              concepts.current = { isLoading: false, concepts: newConcepts };
            }
          }

          // A newer asset was chosen while loading: discard this result.
          if (requestId !== assetRequestId.current) {
            return;
          }

          // Update form based on asset dimensions
          const assetWidth = newAsset.metadata.width ?? 512;
          const assetHeight = newAsset.metadata.height ?? 512;

          if (newAsset.metadata.prompt) {
            toggleCollapsedSection("prompt", false);
            updates.prompt = newAsset.metadata.prompt;
          }

          updates = {
            ...updates,
            ...calculateDimensions({
              baseWidth: assetWidth,
              baseHeight: assetHeight,
              aspectRatio: form.aspectRatio,
              outputScale: DEFAULT_REFRAME_FORM.outputScale,
            }),
          };
        }
      }

      // Handle aspect ratio or output scale changes
      if (
        (key === "aspectRatio" || key === "outputScale") &&
        form.asset?.metadata?.width &&
        form.asset.metadata?.height
      ) {
        updates = {
          ...updates,
          ...calculateDimensions({
            baseWidth: form.asset.metadata.width,
            baseHeight: form.asset.metadata.height,
            aspectRatio:
              key === "aspectRatio"
                ? (value as AspectRatioValue)
                : form.aspectRatio,
            outputScale:
              key === "outputScale" ? (value as number) : form.outputScale,
          }),
        };
      }

      // Handle input scale change
      if (
        key === "inputScale" &&
        form.asset?.metadata?.width &&
        form.asset.metadata?.height
      ) {
        const inputScale = value as number;
        updates.inputWidth = form.asset.metadata.width * inputScale;
        updates.inputHeight = form.asset.metadata.height * inputScale;
        updates.inputScale = inputScale;
      }

      // Make dimensions multiple of 8
      if (updates.targetWidth) {
        updates.targetWidth = makeNumberAMultipleOfEight(updates.targetWidth);
      }
      if (updates.targetHeight) {
        updates.targetHeight = makeNumberAMultipleOfEight(updates.targetHeight);
      }

      // Notify the user when the 2048px cap shrank an input that was
      // previously at full scale. Explicit inputScale edits are excluded.
      if (
        updates.inputScale !== undefined &&
        key !== "inputScale" &&
        form.inputScale === 1 &&
        updates.inputScale < 1
      ) {
        debouncedInfoToast({
          description:
            "Your input has been resized to meet the maximum output resolution",
        });
      }

      setForm((prev) => ({ ...prev, ...updates }));
    },
    [
      form.asset?.metadata.width,
      form.asset?.metadata.height,
      form.inputScale,
      form.aspectRatio,
      form.outputScale,
      loadFluxConcepts,
      toggleCollapsedSection,
      debouncedInfoToast,
      calculateDimensions,
    ]
  );

  // An input that is at least as large as the target in both directions
  // leaves nothing to expand.
  const errors = useMemo(() => {
    const errors: InferenceBuilderErrors = {};

    if (!form.asset) {
      return errors;
    }

    if (
      form.inputWidth >= form.targetWidth &&
      form.inputHeight >= form.targetHeight
    ) {
      errors.inputDimensionsTooBig = {
        message:
          "Input dimensions cannot be larger or equal than target dimensions in both width and height",
        type: "error",
      };
    }

    return errors;
  }, [
    form.asset,
    form.inputHeight,
    form.inputWidth,
    form.targetHeight,
    form.targetWidth,
  ]);

  // Request body used only for the CU cost lookup. Concepts are added at
  // submit time in handleReframeJob.
  const body = useMemo(() => {
    if (form.asset) {
      return createReframeBody({
        ...form,
        asset: form.asset,
      });
    }
    return undefined;
  }, [form]);

  const isDisabled =
    !form.asset ||
    Object.values(errors).some((error) => error.type === "error");

  const { cuCost, isCuLoading } = useAssetReframeCu(
    isDisabled ? undefined : body
  );

  /**
   * Submits the reframe inference for `asset` using the current form values
   * and the loaded concepts, then tracks the analytics event.
   *
   * Resolves to the created job. Resolves to `undefined` when the form is
   * disabled or the request fails; on failure an error toast is shown.
   */
  const handleReframeJob = useCallback(
    async ({
      asset,
      trackingExtraParams,
    }: {
      asset: GetAssetsByAssetIdApiResponse["asset"];
      trackingExtraParams: Record<string, unknown>;
    }) => {
      if (isDisabled) {
        return;
      }

      try {
        setIsReframeLoading(true);

        // Wait for an in-flight asset change to finish collecting concepts.
        // setValue always clears this flag, even if the load fails.
        while (concepts.current.isLoading) {
          await new Promise((resolve) => setTimeout(resolve, 500));
        }

        const response = await createReframeTrigger({
          projectId: selectedProject.id,
          body: {
            ...createReframeBody({ ...form, asset }),
            concepts: concepts.current.concepts,
          },
        })
          .unwrap()
          .then(removeCuFromType);

        Track(AnalyticsEvents.Reframe.GeneratedReframe, {
          ...form,
          ...trackingExtraParams,
        });

        return response.job;
      } catch {
        errorToast({
          title: "Expand failed",
          description:
            "There was an error expanding the image. Please try again.",
        });
      } finally {
        setIsReframeLoading(false);
      }
    },
    [createReframeTrigger, errorToast, form, isDisabled, selectedProject.id]
  );

  return {
    form,
    setForm,
    setValue,
    handleReframeJob,
    isReframeLoading,
    cuCost,
    isCuLoading,
    errors,
    isDisabled,
  };
}

/**
 * Maps reframe form values onto the request body of the reframe inference
 * endpoint.
 *
 * Field mapping:
 * - `asset.id` becomes `image`.
 * - `inputScale` is sent as `resizeOption`.
 * - `overlapPercentage` (a percent) is halved and turned into a fraction, so
 *   30 becomes 0.15. NOTE(review): presumably the API expects a per-side
 *   overlap; confirm against the API spec.
 *
 * UI-only fields (`inputWidth`, `inputHeight`, `outputScale`, `aspectRatio`)
 * are not sent. Concepts are added separately by the caller.
 */
export function createReframeBody(
  params: ReframeForm & {
    asset: GetAssetsByAssetIdApiResponse["asset"];
  }
): PostReframeInferencesApiArg["body"] {
  const {
    asset,
    inputScale,
    overlapPercentage,
    targetHeight,
    targetWidth,
    inputLocation,
    numInferenceSteps,
    prompt,
    seed,
  } = params;

  const overlapFraction = overlapPercentage / 2 / 100;

  return {
    image: asset.id,
    resizeOption: inputScale,
    overlapPercentage: overlapFraction,
    targetHeight,
    targetWidth,
    inputLocation,
    numInferenceSteps,
    prompt,
    seed,
  };
}
