import { useCallback, useMemo } from 'react';

import {
  getStudioByIdEndpoint,
  LAB_COMPARE_ENDPOINT,
  LAB_EVALUATE_ENDPOINT,
  LAB_EVALUATE_SINGLE_ENDPOINT,
  LAB_FEEDBACK_ENDPOINT,
  LAB_GET_SESSION_RESULTS,
  LAB_GET_SESSIONS_ENDPOINT,
  LAB_INFERENCE_STREAMING_ENDPOINT,
  LAB_SAVE_ENDPOINT,
  LAB_SAVE_IMAGE,
  LAB_SAVE_NEW_VERSION_ENDPOINT,
  LAB_SESSION_ENDPOINT,
  LAB_STREAM_COMPLETION_ENDPOINT,
  LAB_UPDATE_EVAL_METRIC_ID_ENDPOINT,
  LAB_UPDATE_SESSION_NAME_ENDPOINT,
} from '../../utils/constants';

import { useApiClient } from '../useApiClient';
import { useLabStore } from '../../store/useLabStore';
import {
  ConfigurationUnit,
  EvaluateRequestUnit,
  InferenceRequest,
  NewVersionIdsResponseSchema,
  SaveLabSessionRequest,
} from '../../utils/types';
import { useFetchApiClient } from '../useFetchApiClient';

/**
 * Wraps an API-client request function so that lab-store error state is
 * cleared before each request and populated when the request throws.
 *
 * Pure factory (no React state), so it lives at module scope. The hook below
 * memoises its results.
 *
 * @param setOneError - Called with no argument before the request (to clear
 *   it). Called with the error's `detail` field when the request throws.
 * @param setOtherError - Optional second setter. It is only cleared before the
 *   request and is never set on failure.
 * @param method - `'POST'` forwards `(url, data, headers)` to the request
 *   function. Any other value forwards `(url, headers)`.
 * @returns An async `(requestType, url, data?, headers?)` function. It resolves
 *   to whatever `requestType` resolves to and re-throws any error after
 *   logging it and recording it via `setOneError`.
 */
const withError = (setOneError, setOtherError?, method = 'POST') => {
  return async (requestType, url, data = null, headers = {}) => {
    try {
      setOneError();
      if (setOtherError) setOtherError();
      if (method === 'POST') return await requestType(url, data, headers);
      return await requestType(url, headers);
    } catch (err) {
      console.error(err);
      // Optional chaining: a rejection with null/undefined must not replace
      // the original failure with a TypeError thrown from this handler.
      setOneError(err?.detail);
      throw err;
    }
  };
};

/**
 * React hook exposing the lab API calls (sessions, inference, evaluation,
 * feedback, versions, images).
 *
 * Most calls go through `withError`, which clears and sets error state in
 * `useLabStore`. Errors are re-thrown to the caller.
 */
const useLabApiEndpoints = () => {
  const { streamPostRequest } = useFetchApiClient();
  const { getRequest, postRequest, deleteRequest } = useApiClient();
  // NOTE(review): this selector returns a fresh object on every call. If the
  // store is zustand, a shallow equality function may be needed to avoid
  // extra re-renders. Confirm against the store library/version.
  const { setError, sessions, setSessions, setErrorEvaluateEndpoint } =
    useLabStore((state) => ({
      setError: state.setError,
      sessions: state.sessions,
      setSessions: state.setSessions,
      setErrorEvaluateEndpoint: state.setErrorEvaluateEndpoint,
    }));

  // Memoised so the useCallback-wrapped endpoints below keep a stable identity
  // across renders. `withError` returns a new closure on every call.
  const handleGETRequest = useMemo(
    () => withError(setError, setErrorEvaluateEndpoint, 'GET'),
    [setError, setErrorEvaluateEndpoint]
  );
  const handleRequest = useMemo(
    () => withError(setError, setErrorEvaluateEndpoint),
    [setError, setErrorEvaluateEndpoint]
  );
  // Evaluation calls only touch the evaluate-specific error slot.
  const handleEvaluateRequest = useMemo(
    () => withError(setErrorEvaluateEndpoint),
    [setErrorEvaluateEndpoint]
  );

  /** GETs `LAB_GET_SESSIONS_ENDPOINT` and returns the response's `data`. */
  const getLabSessions = async () => {
    const response = await handleGETRequest(
      getRequest,
      LAB_GET_SESSIONS_ENDPOINT
    );
    return response.data;
  };

  /** POSTs `{ result_ids }` to `LAB_GET_SESSION_RESULTS` and returns `data`. */
  const getInferenceResults = async (result_ids: number[][]) => {
    const response = await handleRequest(postRequest, LAB_GET_SESSION_RESULTS, {
      result_ids: result_ids,
    });
    return response.data;
  };

  /** GETs the studio entry at `getStudioByIdEndpoint(id)` and returns `data`. */
  const getVersionData = async (id) => {
    const response = await handleGETRequest(
      getRequest,
      getStudioByIdEndpoint(id)
    );
    return response.data;
  };

  /**
   * Optimistically removes the session at `sessionIdx` from the store, then
   * asks the backend to delete it.
   *
   * The pre-deletion snapshot is restored if the request throws (the error is
   * then re-thrown) or if it resolves to a falsy value.
   *
   * NOTE(review): `sessions` is copied with an object spread, so it is treated
   * as an index-keyed object. If it is actually an array, this turns it into a
   * plain object. Confirm against the store's type. The rollback also restores
   * the snapshot captured by this render, not the latest store state.
   */
  const deleteLabSession = async (sessionIdx) => {
    const sessionToDelete = sessions[sessionIdx];
    if (!sessionToDelete) {
      // Stale/unknown index: nothing to delete. Avoid a TypeError on `.id`.
      console.warn(`deleteLabSession: no session at index ${sessionIdx}`);
      return;
    }
    const sessionId = sessionToDelete.id;

    const restoreSession = () =>
      setSessions({ ...sessions, [sessionIdx]: sessionToDelete });

    // Optimistic UI update: drop the session before the server confirms.
    const updatedSessions = { ...sessions };
    delete updatedSessions[sessionIdx];
    setSessions(updatedSessions);

    let resp;
    try {
      resp = await handleRequest(
        deleteRequest,
        `${LAB_SESSION_ENDPOINT}/${sessionId}`,
        {}
      );
    } catch (err) {
      // handleRequest re-throws on failure, so a falsy-result check alone
      // would never roll back a failed delete.
      restoreSession();
      throw err;
    }
    if (!resp) {
      restoreSession();
    }
  };

  /** POSTs the batch of inference requests to `LAB_COMPARE_ENDPOINT`. */
  const makeInferenceCalls = async (req: InferenceRequest[]) => {
    return await handleRequest(postRequest, LAB_COMPARE_ENDPOINT, req);
  };

  /**
   * Streams one inference request via `streamPostRequest`.
   *
   * Not wrapped in `withError`, so store error state is not touched here.
   */
  const makeInferenceCallStream = async (req: InferenceRequest, index) => {
    return await streamPostRequest(
      LAB_INFERENCE_STREAMING_ENDPOINT,
      req,
      index
    );
  };

  /** POSTs `req` to `LAB_STREAM_COMPLETION_ENDPOINT` and returns `data`. */
  const getFinishedStreamData = async (req) => {
    const response = await handleRequest(
      postRequest,
      LAB_STREAM_COMPLETION_ENDPOINT,
      req
    );
    return response.data;
  };

  /** POSTs the session payload to `LAB_SAVE_ENDPOINT` and returns `data`. */
  const saveLabSession = async (req: SaveLabSessionRequest) => {
    const response = await handleRequest(postRequest, LAB_SAVE_ENDPOINT, req);
    return response.data;
  };

  /**
   * Renames a session. Resolves to `{ data: sessionId }` regardless of the
   * response body.
   */
  const updateSessionName = useCallback(
    async (sessionId, newName) => {
      await handleRequest(
        postRequest,
        `${LAB_UPDATE_SESSION_NAME_ENDPOINT}/${sessionId}`,
        { name: newName }
      );
      return { data: sessionId };
    },
    [postRequest, handleRequest]
  );

  /**
   * Replaces a session's evaluation metric ids. Resolves to
   * `{ data: sessionId }` regardless of the response body.
   */
  const updateSessionEvalMetricIds = useCallback(
    async (sessionId, newEvalMetricIds) => {
      await handleRequest(
        postRequest,
        `${LAB_UPDATE_EVAL_METRIC_ID_ENDPOINT}/${sessionId}`,
        { evaluation_metric_ids: newEvalMetricIds }
      );
      return { data: sessionId };
    },
    [postRequest, handleRequest]
  );

  /**
   * Evaluates a single unit against the given metric ids. Errors go to the
   * evaluate-specific error slot.
   */
  const evaluateSingleAPI = useCallback(
    async (evalMetricIds: number[], unit: EvaluateRequestUnit) => {
      return await handleEvaluateRequest(
        postRequest,
        LAB_EVALUATE_SINGLE_ENDPOINT,
        {
          evaluation_metric_ids: evalMetricIds,
          unit: unit,
        }
      );
    },
    [handleEvaluateRequest, postRequest]
  );

  /**
   * Runs the named automatic evaluations on one unit.
   *
   * NOTE(review): the endpoint path is hard-coded rather than imported from
   * constants, and errors go to the general error slot (unlike the other
   * evaluate calls). Confirm both are intended.
   */
  const evaluateAutoAPI = useCallback(
    async (
      autoEvalNames: string[],
      unit: EvaluateRequestUnit,
      autoEvalInputs
    ) => {
      return await handleRequest(postRequest, '/api/lab/evaluate/auto', {
        evaluation_metric_ids: [],
        unit: unit,
        auto_eval_names: autoEvalNames,
        auto_eval_inputs: autoEvalInputs,
      });
    },
    [handleRequest, postRequest]
  );

  /** POSTs a configuration to `LAB_SAVE_NEW_VERSION_ENDPOINT`. */
  const saveNewVersion = useCallback(
    async (
      configuration: ConfigurationUnit
    ): Promise<{ data: NewVersionIdsResponseSchema, status: number }> => {
      return await handleRequest(
        postRequest,
        LAB_SAVE_NEW_VERSION_ENDPOINT,
        configuration
      );
    },
    [handleRequest, postRequest]
  );

  /**
   * Evaluates several units against the given metric ids. Errors go to the
   * evaluate-specific error slot.
   */
  const evaluateAPI = useCallback(
    async (evalMetricIds: number[], units: EvaluateRequestUnit[]) => {
      return await handleEvaluateRequest(postRequest, LAB_EVALUATE_ENDPOINT, {
        evaluation_metric_ids: evalMetricIds,
        units: units,
      });
    },
    [handleEvaluateRequest, postRequest]
  );

  /** POSTs `{ result_id, feedback }` to `LAB_FEEDBACK_ENDPOINT`. */
  const collectFeedbackAPI = useCallback(
    async (resultId: number, feedback) => {
      return await handleRequest(postRequest, LAB_FEEDBACK_ENDPOINT, {
        result_id: resultId,
        feedback: feedback,
      });
    },
    [postRequest, handleRequest]
  );

  /** Uploads `file` as the `image` field of a multipart form to `LAB_SAVE_IMAGE`. */
  const saveImageAPI = useCallback(
    async (file) => {
      const formData = new FormData();
      formData.append('image', file);

      return await handleRequest(postRequest, LAB_SAVE_IMAGE, formData, {
        'Content-Type': 'multipart/form-data',
      });
    },
    [postRequest, handleRequest]
  );

  return {
    getLabSessions,
    getVersionData,
    deleteLabSession,
    makeInferenceCalls,
    getInferenceResults,
    updateSessionName,
    saveLabSession,
    updateSessionEvalMetricIds,
    evaluateAPI,
    evaluateSingleAPI,
    collectFeedbackAPI,
    makeInferenceCallStream,
    getFinishedStreamData,
    evaluateAutoAPI,
    saveNewVersion,
    saveImageAPI,
  };
};

export default useLabApiEndpoints;
