import { create } from 'zustand';
import { immer } from 'zustand/middleware/immer';
import type { LogProbability } from '../utils/types';

/**
 * Figures recorded for one inference; the store keeps one of these per
 * stream id. Units of `cost` and `latency` are set by whoever calls the
 * store's stats setter (not visible here).
 */
type InferenceStats = {
  /** Token count; presumably prompt + completion — confirm with producer. */
  tokens: number;
  tokens_prompt: number;
  tokens_completion: number;
  cost: number;
  latency: number;
  /** `null` when no perplexity value is available. */
  perplexity: number | null;
};

/** Per-id state of one streamed completion held by the streaming store. */
type StreamMessage = {
  /** Text received so far. */
  content: string;
  /** `true` once the stream has been marked finished. */
  done: boolean;
  /** Stream start time (presumably a `performance.now()` value — confirm). */
  start: DOMHighResTimeStamp;
  /** Stream end time, same clock as `start`. */
  end: DOMHighResTimeStamp;
  /** Whether stats for this stream have been flagged as recorded. */
  statsSet: boolean;
  /** Backend identifier for the inference, `null` if not known. */
  inferenceId: string | null;
  /** Token log-probabilities for `content`, or `null`. */
  logprobs: LogProbability[] | null;
};

/**
 * Data portion of the streaming store. Every map is keyed by the same
 * caller-chosen stream id; `error` is shared by all streams.
 */
type StreamingState = {
  error: string | null;
  loading: Record<string, boolean>;
  messages: Record<string, StreamMessage>;
  stats: Record<string, InferenceStats>;
};

/**
 * Starting contents of the streaming store: no error and no per-id
 * loading flags, messages, or stats.
 */
const initialState: StreamingState = {
  // Nothing has failed yet.
  error: null,
  // Per-id maps start empty; the store's setters add entries by id.
  loading: {},
  messages: {},
  stats: {},
};

/** Mutators exposed by `useStreamingStore` alongside its state. */
type StreamingActions = {
  setError: (error: string | null) => void;
  setLoading: (id: string, loading: boolean) => void;
  setMessage: (
    id: string,
    content: string,
    logprobs?: LogProbability[]
  ) => void;
  setInferenceId: (id: string, infId: string) => void;
  setStreamStart: (id: string, start: number) => void;
  setStreamDone: (id: string, end: number) => void;
  setMessageStatsSet: (id: string) => void;
  setStats: (id: string, stats: InferenceStats) => void;
  clearAll: () => void;
  clearById: (id: string) => void;
};

/**
 * Zustand store (wrapped in the immer middleware, so setters mutate a draft)
 * holding streamed inference output keyed by a caller-chosen id: per-id
 * loading flags, messages and stats, plus one store-wide error.
 *
 * The explicit `create<T>()(...)` form types the state and actions up front,
 * so writes to fields that do not exist are compile errors.
 *
 * Token log-probabilities live on each `StreamMessage` (not at the top
 * level), so they are cleared together with their message.
 *
 * NOTE(review): the message setters merge into `state.messages[id]`. If no
 * message exists for `id` yet, they create an entry holding only the fields
 * they write (e.g. `setStreamStart` before `setMessage` leaves `content`
 * undefined). Confirm callers establish the message first.
 */
export const useStreamingStore = create<StreamingState & StreamingActions>()(
  immer((set) => ({
    ...initialState,

    /** Sets the store-wide error, or clears it with `null`. */
    setError: (error: string | null) =>
      set((state) => {
        state.error = error;
      }),

    /** Records whether the stream `id` is currently loading. */
    setLoading: (id: string, loading: boolean) =>
      set((state) => {
        state.loading[id] = loading;
      }),

    /**
     * Replaces the content and logprobs of message `id` and resets its
     * `done` / `statsSet` flags to `false`. Other fields (`start`, `end`,
     * `inferenceId`) are kept.
     */
    setMessage: (
      id: string,
      content: string,
      logprobs: LogProbability[] = []
    ) =>
      set((state) => {
        state.messages[id] = {
          ...state.messages[id],
          content,
          logprobs,
          done: false,
          statsSet: false,
        };
      }),

    /** Attaches the backend inference id to message `id`. */
    setInferenceId: (id: string, infId: string) =>
      set((state) => {
        state.messages[id] = {
          ...state.messages[id],
          inferenceId: infId,
        };
      }),

    /** Records the start timestamp of message `id`. */
    setStreamStart: (id: string, start: number) =>
      set((state) => {
        state.messages[id] = {
          ...state.messages[id],
          start,
        };
      }),

    /** Marks message `id` as finished and records its end timestamp. */
    setStreamDone: (id: string, end: number) =>
      set((state) => {
        state.messages[id] = {
          ...state.messages[id],
          done: true,
          end,
        };
      }),

    /** Flags message `id` as having had its stats recorded. */
    setMessageStatsSet: (id: string) =>
      set((state) => {
        state.messages[id] = {
          ...state.messages[id],
          statsSet: true,
        };
      }),

    /** Stores (replacing any previous value) the stats for stream `id`. */
    setStats: (id: string, stats: InferenceStats) =>
      set((state) => {
        state.stats[id] = stats;
      }),

    /**
     * Resets every stream: clears the error and empties all per-id maps.
     * Message logprobs are discarded along with `messages`.
     */
    clearAll: () =>
      set((state) => {
        state.error = null;
        state.loading = {};
        state.messages = {};
        state.stats = {};
      }),

    /**
     * Removes everything stored for stream `id`: loading flag, message
     * (including its logprobs), and stats. Missing keys are a no-op.
     */
    clearById: (id: string) =>
      set((state) => {
        delete state.loading[id];
        delete state.messages[id];
        delete state.stats[id];
      }),
  }))
);
