import { create } from 'zustand';
import { immer } from 'zustand/middleware/immer';
import type {
  Inference,
  InputsCellsObj,
  LabState,
  SessionObj,
  VersionItemsObj,
  VersionObj,
} from '../utils/types';
import {
  CommentObj,
  InferenceObj,
  InputsCellType,
  LabSessionComment,
} from '../utils/types';
import { incrementId } from '../utils/helpers';

/** localStorage key under which the "show logprobs" preference is read. */
const SHOWING_LOGPROBS_KEY = 'showingLogprobs';

/**
 * Reads the persisted "show logprobs" flag.
 *
 * The flag is considered set only when the stored value is exactly the string
 * `'true'`. Any situation in which storage cannot be read (no `localStorage`
 * global, or the access throws) yields `false` instead of aborting module
 * evaluation, so importing this store never crashes on its own.
 *
 * @returns `true` when the stored value is `'true'`, otherwise `false`.
 */
function readShowingLogprobs(): boolean {
  try {
    return (
      typeof localStorage !== 'undefined' &&
      localStorage.getItem(SHOWING_LOGPROBS_KEY) === 'true'
    );
  } catch {
    // Storage access itself can throw; fall back to the default.
    return false;
  }
}

// Evaluated once at module load; later changes go through `setShowingLogprobs`.
const showingLogprobs = readShowingLogprobs();

/** Data portion of the store as it exists before any action has run. */
const initialState: LabState = {
  error: null,
  loading: false,
  errorEvaluateEndpoint: null,
  saving: false,
  validationError: null,
  showingLogprobs: showingLogprobs,
  pastVersions: {},
  sessions: {},
  selectedSession: {},
  versionItems: {},
  inputsCells: {},
  inferenceCells: {},
  openInLabData: {},
  sessionComments: {},
};

/**
 * Full shape of the lab store: the data fields inherited from `LabState` plus
 * every action the store exposes.
 *
 * Ids are numeric keys into the corresponding record held in state. Inference
 * cells are addressed by a `(columnId, rowID)` pair.
 */
interface LabStoreState extends LabState {
  // --- Global flags and messages ---
  setError: (error: string | null) => void;
  /** Sets `errorEvaluateEndpoint`. */
  setErrorEvaluateEndpoint: (error: string | null) => void;
  setLoading: (loading: boolean) => void;
  setSaving: (saving: boolean) => void;
  setShowingLogprobs: (showingLogprobs: boolean) => void;
  setValidationError: (validationError: string | null) => void;

  // --- Past versions and sessions ---
  /** Replaces the whole `pastVersions` record. */
  setPastVersions: (pastVersions: VersionObj) => void;
  /** Shallow-merges `version` into the entry stored under `id`. */
  updatePastVersion: (id: number, version: VersionObj) => void;
  setSelectedSession: (selectedSession: SessionObj) => void;
  /** Replaces the whole `sessions` record. */
  setSessions: (sessions: SessionObj) => void;
  /** Replaces (does not merge) the entry stored under `id`. */
  updateSession: (id: number, session: SessionObj) => void;

  // --- Version items ---
  /** Replaces the whole `versionItems` record. */
  setVersionItems: (versionItems: VersionItemsObj) => void;
  /** Stores `versionItem` under `id`, tagged with `selected_session_id`. */
  addVersionItem: (id: number, sessionId: number, versionItem: any) => void;
  /** Ensures the version item has a `function_ids` set. */
  addAssignedFunction: (id: number) => void;
  addToAssignedFunctionsSet: (id: number, funcId: number) => void;
  removeFromAssignedFunctionsSet: (id: number, funcId: number) => void;
  deleteVersionItem: (id: number) => void;
  /** Shallow-merges `versionItem` into the entry stored under `id`. */
  updateVersionItem: (id: number, versionItem: any) => void;
  /** Sets `function_call` on the version item stored under `id`. */
  setFunctionCall: (id: number, functionCall: string) => void;
  /** Same effect as `setFunctionCall`. */
  setVersionItemFunctionCall: (id: number, functionCall: string) => void;

  // --- Inputs cells ---
  /** Replaces the whole `inputsCells` record. */
  setInputsCells: (inputsCells: InputsCellsObj) => void;
  /**
   * Adds an inputs cell. A falsy `id` makes the store derive one via
   * `incrementId`; omitting `inputsCell` creates an empty cell.
   */
  addInputsCell: (id: number, inputsCell?: InputsCellType) => void;
  deleteInputsCell: (id: number) => void;
  updateInputsCellValue: (id: number, values: string[]) => void;
  updateInputsCellTags: (id: number, tags: string[]) => void;
  /** Replaces the cell stored under `id` with a copy of `updatedInputsCell`. */
  setSingleInputsCellWithTestCase: (
    id: number,
    updatedInputsCell: InputsCellType
  ) => void;

  // --- Inference cells ---
  /** Merges the given columns into `inferenceCells`. */
  setInferenceCells: (inferenceCells: InferenceObj) => void;
  setInferenceCellLoading: (
    columnId: number,
    rowID: number,
    loading: boolean
  ) => void;
  setInferenceCellEvalLoading: (
    columnId: number,
    rowID: number,
    evalLoading: boolean
  ) => void;
  /** Shallow-merges `inference` into the addressed cell. */
  updateInferenceCell: (
    columnId: number,
    rowID: number,
    inference: Inference
  ) => void;
  /** Resets the result fields of the addressed cell. */
  clearInferenceCellStats: (columnId: number, rowID: number) => void;
  /** Stores `inference` as the addressed cell, creating the column if needed. */
  addInferenceCell: (
    columnId: number,
    rowID: number,
    inference: Inference
  ) => void;
  deleteInferenceCell: (columnId: number, rowID: number) => void;

  // --- Miscellaneous ---
  setOpenInLabData: (openInLabData: any) => void;

  // --- Session comments ---
  addSessionComment: (
    sessionId: number,
    commentId: number,
    comment: LabSessionComment
  ) => void;
  deleteSessionComment: (sessionId: number, commentId: number) => void;
  /** Merges `comments` into the comments already stored for `sessionId`. */
  setSessionComments: (sessionId: number, comments: CommentObj) => void;
}

/**
 * Zustand store (wrapped in the immer middleware) that holds all lab state
 * together with the actions that mutate it.
 *
 * The store is typed through the curried `create<LabStoreState>()` form so the
 * hook can be called both without arguments and with a selector.
 *
 * Actions that modify or delete an existing entry (version item, inputs cell,
 * inference cell, session comment) do nothing when the addressed entry or its
 * parent container is missing, rather than throwing inside the state update.
 *
 * NOTE(review): `function_ids` is a `Set`; immer only drafts `Set` values once
 * `enableMapSet()` has been called — confirm the application bootstrap does so.
 */
export const useLabStore = create<LabStoreState>()(
  immer((set) => {
    // Shared by `setFunctionCall` and `setVersionItemFunctionCall`, which have
    // always had the same effect.
    const assignFunctionCall = (id: number, functionCall: string) =>
      set((state) => {
        const item = state.versionItems[id];
        if (item) {
          item.function_call = functionCall;
        }
      });

    return {
      ...initialState,

      // --- Global flags and messages ---
      setError: (error: string | null) =>
        set((state) => {
          state.error = error;
        }),
      setErrorEvaluateEndpoint: (error: string | null) =>
        set((state) => {
          state.errorEvaluateEndpoint = error;
        }),
      setLoading: (loading: boolean) =>
        set((state) => {
          state.loading = loading;
        }),
      setSaving: (saving: boolean) =>
        set((state) => {
          state.saving = saving;
        }),
      setShowingLogprobs: (showingLogprobs: boolean) =>
        set((state) => {
          state.showingLogprobs = showingLogprobs;
        }),
      setValidationError: (validationError: string | null) =>
        set((state) => {
          state.validationError = validationError;
        }),

      // --- Past versions and sessions ---
      setPastVersions: (pastVersions: VersionObj) =>
        set((state) => {
          state.pastVersions = pastVersions;
        }),
      // Shallow merge: fields absent from `version` keep their stored value.
      updatePastVersion: (id, version) =>
        set((state) => {
          state.pastVersions[id] = { ...state.pastVersions[id], ...version };
        }),
      setSelectedSession: (selectedSession: SessionObj) =>
        set((state) => {
          state.selectedSession = selectedSession;
        }),
      setSessions: (sessions: SessionObj) =>
        set((state) => {
          state.sessions = sessions;
        }),
      // Replaces the stored session outright (no merge).
      updateSession: (id, session) =>
        set((state) => {
          state.sessions[id] = session;
        }),

      // --- Version items ---
      setVersionItems: (versionItems: VersionItemsObj) =>
        set((state) => {
          state.versionItems = versionItems;
        }),
      addVersionItem: (id, sessionId, versionItem) =>
        set((state) => {
          state.versionItems[id] = {
            ...versionItem,
            selected_session_id: sessionId,
          };
        }),
      // Ensures an existing version item has a `function_ids` set.
      addAssignedFunction: (id) =>
        set((state) => {
          const item = state.versionItems[id];
          if (item && !item.function_ids) {
            item.function_ids = new Set();
          }
        }),
      addToAssignedFunctionsSet: (id, funcId) =>
        set((state) => {
          const item = state.versionItems[id];
          if (!item) {
            return;
          }
          if (!item.function_ids) {
            item.function_ids = new Set();
          }
          item.function_ids.add(funcId);
        }),
      // `Set.prototype.delete` is already a no-op for absent members, and the
      // id is used as given so lookup and removal agree on the key.
      removeFromAssignedFunctionsSet: (id, funcId) =>
        set((state) => {
          state.versionItems[id]?.function_ids?.delete(funcId);
        }),
      deleteVersionItem: (id) =>
        set((state) => {
          delete state.versionItems[id];
        }),
      // Shallow merge into the stored item.
      updateVersionItem: (id, versionItem) =>
        set((state) => {
          state.versionItems[id] = {
            ...state.versionItems[id],
            ...versionItem,
          };
        }),
      setFunctionCall: (id, functionCall: string) =>
        assignFunctionCall(id, functionCall),
      setVersionItemFunctionCall: (id, functionCall) =>
        assignFunctionCall(id, functionCall),

      // --- Inputs cells ---
      setInputsCells: (inputsCells: InputsCellsObj) =>
        set((state) => {
          state.inputsCells = inputsCells;
        }),
      addInputsCell: (id, inputsCell?: InputsCellType) =>
        set((state) => {
          // A falsy id (including 0) means "allocate the next id".
          const newId = id || incrementId(state.inputsCells);
          if (inputsCell) {
            state.inputsCells[newId] = inputsCell;
            return;
          }
          // No cell supplied: create an empty one that mirrors the keys of the
          // first existing cell. The keys are copied so that the same array
          // object does not appear twice in the immer-managed tree.
          const keys = Object.values(state.inputsCells)[0]?.keys || [''];
          state.inputsCells[newId] = {
            keys: [...keys],
            values: keys.map(() => ''),
            testCaseId: null,
            tags: [],
            collectionName: null,
            collectionId: null,
          };
        }),
      deleteInputsCell: (id: number) =>
        set((state) => {
          delete state.inputsCells[id];
        }),
      updateInputsCellValue: (id: number, values: string[]) =>
        set((state) => {
          const cell = state.inputsCells[id];
          if (cell) {
            cell.values = values;
          }
        }),
      updateInputsCellTags: (id: number, tags: string[]) =>
        set((state) => {
          const cell = state.inputsCells[id];
          if (cell) {
            cell.tags = tags;
          }
        }),
      setSingleInputsCellWithTestCase: (
        id: number,
        updatedInputsCell: InputsCellType
      ) =>
        set((state) => {
          state.inputsCells[id] = { ...updatedInputsCell };
        }),

      // --- Inference cells ---
      // Merges at column level: an incoming column replaces the stored column
      // with the same id; other stored columns are kept.
      setInferenceCells: (inferenceCells: InferenceObj) =>
        set((state) => {
          state.inferenceCells = {
            ...state.inferenceCells,
            ...inferenceCells,
          };
        }),
      setInferenceCellLoading: (columnId, rowID, loading: boolean) =>
        set((state) => {
          const cell = state.inferenceCells[columnId]?.[rowID];
          if (cell) {
            cell.loading = loading;
          }
        }),
      setInferenceCellEvalLoading: (columnId, rowID, evalLoading: boolean) =>
        set((state) => {
          const cell = state.inferenceCells[columnId]?.[rowID];
          if (cell) {
            cell.evalLoading = evalLoading;
          }
        }),
      // Upsert: a missing row was always created by the merge below; a missing
      // column is now created as well instead of throwing.
      updateInferenceCell: (columnId, rowID, inference: Inference) =>
        set((state) => {
          if (!state.inferenceCells[columnId]) {
            state.inferenceCells[columnId] = {};
          }
          const column = state.inferenceCells[columnId];
          column[rowID] = { ...column[rowID], ...inference };
        }),
      clearInferenceCellStats: (columnId, rowID) =>
        set((state) => {
          const column = state.inferenceCells[columnId];
          if (column?.[rowID]) {
            column[rowID] = {
              ...column[rowID],
              latency: null,
              cost: null,
              completion_tokens: null,
              content: '',
              scores: null,
              feedback: null,
              perplexity: null,
            };
          }
        }),
      addInferenceCell: (columnId, rowID, inference: Inference) =>
        set((state) => {
          if (!state.inferenceCells[columnId]) {
            state.inferenceCells[columnId] = {};
          }
          state.inferenceCells[columnId][rowID] = inference;
        }),
      deleteInferenceCell: (columnId, rowID) =>
        set((state) => {
          const column = state.inferenceCells[columnId];
          if (column) {
            delete column[rowID];
          }
        }),

      // --- Miscellaneous ---
      setOpenInLabData: (openInLabData) =>
        set((state) => {
          state.openInLabData = openInLabData;
        }),

      // --- Session comments ---
      addSessionComment: (sessionId, commentId, comment) =>
        set((state) => {
          if (!state.sessionComments[sessionId]) {
            state.sessionComments[sessionId] = {};
          }
          state.sessionComments[sessionId][commentId] = comment;
        }),
      deleteSessionComment: (sessionId, commentId) =>
        set((state) => {
          const comments = state.sessionComments[sessionId];
          if (comments) {
            delete comments[commentId];
          }
        }),
      // Merges into whatever is stored; spreading a missing entry yields {}.
      setSessionComments: (sessionId, comments) =>
        set((state) => {
          state.sessionComments[sessionId] = {
            ...state.sessionComments[sessionId],
            ...comments,
          };
        }),
    };
  })
);

/**
 * Convenience hook returning a fixed subset of the lab store: a few data
 * slices plus the actions that operate on them.
 *
 * The return type lists exactly the members that are selected, so reading a
 * store member that is not part of this subset is a compile-time error
 * instead of an `undefined` at runtime. Two previously selected names,
 * `updateExampleCellValue` and `assignedFunctions`, are not defined by the
 * store and were therefore always `undefined`; they are no longer selected.
 *
 * NOTE(review): the selector builds a new object on every call, so consumers
 * are notified on every store change; consider a shallow-equality comparison
 * supported by the installed zustand version.
 *
 * @returns The selected slices and actions of the lab store.
 */
export const useLabStoreState = (): Pick<
  LabStoreState,
  | 'versionItems'
  | 'setError'
  | 'setLoading'
  | 'pastVersions'
  | 'updateVersionItem'
  | 'updatePastVersion'
  | 'updateInputsCellValue'
  | 'inputsCells'
  | 'updateInferenceCell'
  | 'inferenceCells'
  | 'setInferenceCellLoading'
  | 'setInferenceCellEvalLoading'
  | 'updateSession'
  | 'setVersionItems'
  | 'setSaving'
  | 'selectedSession'
  | 'setSelectedSession'
  | 'addToAssignedFunctionsSet'
  | 'clearInferenceCellStats'
  | 'addSessionComment'
  | 'deleteSessionComment'
  | 'setSessionComments'
> =>
  useLabStore((state: LabStoreState) => ({
    versionItems: state.versionItems,
    setError: state.setError,
    setLoading: state.setLoading,
    pastVersions: state.pastVersions,
    updateVersionItem: state.updateVersionItem,
    updatePastVersion: state.updatePastVersion,
    updateInputsCellValue: state.updateInputsCellValue,
    inputsCells: state.inputsCells,
    updateInferenceCell: state.updateInferenceCell,
    inferenceCells: state.inferenceCells,
    setInferenceCellLoading: state.setInferenceCellLoading,
    setInferenceCellEvalLoading: state.setInferenceCellEvalLoading,
    updateSession: state.updateSession,
    setVersionItems: state.setVersionItems,
    setSaving: state.setSaving,
    selectedSession: state.selectedSession,
    setSelectedSession: state.setSelectedSession,
    addToAssignedFunctionsSet: state.addToAssignedFunctionsSet,
    clearInferenceCellStats: state.clearInferenceCellStats,
    addSessionComment: state.addSessionComment,
    deleteSessionComment: state.deleteSessionComment,
    setSessionComments: state.setSessionComments,
  }));
