import { tensor, Tensor, Rank, GraphModel } from '@tensorflow/tfjs'
import { WorkspaceId } from 'collections'
import { getDownloadURL, ref } from 'firebase/storage'
import { useEffect, useRef } from 'react'
import { useStorage } from 'reactfire'
import { useModel, LOADED } from './useModel'
import { IReturnSignRecognition } from '../types/keypoints'

/**
 * Argument object accepted by the `predict` function returned from `useSignRecognitionModel`.
 */
export interface IPredictSignRecognitionProps {
  /**
   * Keypoint sequence. Its first axis is what `segments` index into
   * (presumably one entry per video frame — confirm against the producer).
   */
  kpts: Array<Array<Array<number>>>
  /**
   * One `[start, end]` pair per segment; `kpts.slice(start, end)` is the segment's data.
   * Segments with `end - start <= 1` are skipped by `predict`.
   */
  segments: Array<Array<number>>
  /** How many of the highest-scoring classes to report per segment (`predict` uses 5 when omitted). */
  topk?: number
}

/**
 * Prediction for a single segment. `classes[i]` and `metrics[i]` describe the same candidate.
 */
export interface IPredictSignRecognitionReturn {
  /** Candidate classes, in the order produced by the model's `topk` call. */
  classes: Array<{
    /** Class identifier looked up from the workspace's class map. */
    id: string;
    /** Optional gloss text; the hook in this file only fills `id`. */
    glosa?: string;
  }>
  /** Score of each candidate in `classes`, index-aligned with it. */
  metrics: Array<number>
}

/**
 * Tuple returned by `useSignRecognitionModel`: the model loading status first,
 * then the async prediction function.
 */
type UseSignRecoginitionModelFNS = [
  status: string,
  predict: (props: IPredictSignRecognitionProps) => Promise<Array<IPredictSignRecognitionReturn> | undefined>,
]

/**
 * Target `[height, width]` handed to `resizeNearestNeighbor` for every segment, i.e. the first
 * two axes of a segment (presumably frames × keypoints — confirm) are resampled to 15 × 59.
 */
const SEGMENT_SHAPE: [number, number] = [15, 59]

/**
 * Slices `kpts` by each `[start, end]` pair in `segments` and resamples every slice to
 * `SEGMENT_SHAPE` with nearest-neighbour interpolation. Segments with `end - start <= 1`
 * are skipped. All intermediate tensors are disposed, even when an op throws.
 *
 * @returns One resampled nested array per kept segment, in `segments` order.
 */
const resampleSegments = async (kpts: number[][][], segments: number[][]): Promise<number[][][][]> => {
  const resampled: number[][][][] = []
  for (const seg of segments) {
    if (seg[1] - seg[0] > 1) {
      const raw = tensor(kpts.slice(seg[0], seg[1])) as Tensor<Rank>
      let resized: Tensor<Rank> | undefined
      try {
        resized = raw.resizeNearestNeighbor(SEGMENT_SHAPE)
        resampled.push((await resized.array()) as number[][][])
      } finally {
        // Both the source tensor and its resized copy own backend memory.
        raw.dispose()
        resized?.dispose()
      }
    }
  }
  return resampled
}

/**
 * React hook exposing the sign-recognition model of a workspace.
 *
 * The TF.js model is loaded through `useModel` from
 * `/workspaces/<workspace>/models/tfjs/model.json`. The class map is downloaded from Firebase
 * Storage at `/workspaces/<workspace>/models/tfjs/classes_sign.json`. It is re-downloaded
 * whenever the workspace changes.
 *
 * @param workspace - Workspace whose model and class map are used.
 * @returns `[status, predict]`: the loading status reported by `useModel`, and an async
 *   function that classifies keypoint segments.
 */
export const useSignRecognitionModel = (workspace: WorkspaceId): UseSignRecoginitionModelFNS => {
  const classesUrl = `/workspaces/${workspace}/models/tfjs/classes_sign.json`
  const modelUrl = `/workspaces/${workspace}/models/tfjs/model.json`
  // Loading status of the model, plus the loaded model wrapper.
  const [status, modelData] = useModel({ modelUrl })
  // Class index (as an object key) -> class id; null until the download finishes.
  const classes = useRef<Record<string, string> | null>(null)
  const storage = useStorage()

  /**
   * Classifies every segment of `kpts` and returns the `topk` best classes per segment.
   *
   * Returns `undefined` (with a warning) while the model is not loaded, and `undefined` when
   * `kpts` is empty. Returns `[]` when no segment spans more than one entry. `topk` is clamped
   * to the model's output width. Class ids missing from the class map (or requested before it
   * has downloaded) are reported as `''`.
   */
  const predict = async ({ kpts, segments, topk = 5 }: IPredictSignRecognitionProps) => {
    // Wait for the model to finish loading.
    if (status !== LOADED || !modelData || !modelData.model) {
      console.warn('O modelo não está carregado.')
      return
    }

    if (!kpts || kpts.length === 0) return

    const keypoints = await resampleSegments(kpts, segments)
    // An empty batch cannot be turned into a rank-4 input, so there is nothing to classify.
    if (keypoints.length === 0) return []

    const input = tensor(keypoints)
    let output: Tensor<Rank> | undefined
    let top: { values: Tensor<Rank>; indices: Tensor<Rank> } | undefined
    try {
      const modelFormat = await modelData.getModelType()
      // Segments are already SEGMENT_SHAPE-sized, so no extra resize is needed for either format.
      output =
        modelFormat === 'GraphModel'
          ? ((await (modelData.model as GraphModel).executeAsync(input)) as Tensor<Rank>)
          : (modelData.model.predict(input) as Tensor<Rank>)

      // Tensor.topk rejects k larger than the last axis, so never ask for more than the class count.
      const k = Math.min(topk, output.shape[output.shape.length - 1] ?? topk)
      top = output.topk(k)
      const [valuesData, indicesData] = await Promise.all([top.values.data(), top.indices.data()])
      // Convert once; each segment's results are a contiguous run of k entries.
      const scores = Array.from(valuesData)
      const classIndices = Array.from(indicesData)
      const classMap = classes.current

      // Finally, map each class index to its class id.
      const result: IReturnSignRecognition[] = []
      for (let i = 0; i < keypoints.length; i++) {
        const start = i * k
        const end = start + k
        result.push({
          classes: classIndices.slice(start, end).map((idx) => ({ id: classMap?.[idx] ?? '' })),
          metrics: scores.slice(start, end),
        })
      }
      return result
    } finally {
      // Every tensor created above owns backend memory that is not garbage-collected.
      input.dispose()
      output?.dispose()
      top?.values.dispose()
      top?.indices.dispose()
    }
  }

  useEffect(() => {
    // Drop the previous workspace's class map so predictions never use stale labels.
    classes.current = null
    let cancelled = false

    const downloadClasses = async () => {
      try {
        const url = await getDownloadURL(ref(storage, classesUrl))
        const response = await fetch(url)
        if (!response.ok) {
          throw new Error(`Failed to download ${classesUrl}: HTTP ${response.status}`)
        }
        const data = (await response.json()) as Record<string, string>
        // Ignore responses that arrive after the workspace changed or the component unmounted.
        if (!cancelled) classes.current = data
      } catch (err) {
        console.error(err)
      }
    }

    void downloadClasses()
    return () => {
      cancelled = true
    }
  }, [classesUrl, storage])

  return [status, predict]
}

// Default export of the same hook that is also exported by name, so either import style works.
export default useSignRecognitionModel
