import { loadLayersModel, LayersModel, loadGraphModel, GraphModel } from '@tensorflow/tfjs'
import { LoadOptions } from '@tensorflow/tfjs-core/dist/io/types'
import { useEffect, useState } from 'react'
import { getDownloadURL, ref } from 'firebase/storage'
import { useStorage } from 'reactfire'
import * as tf from '@tensorflow/tfjs'
import * as tfWasm from '@tensorflow/tfjs-backend-wasm'

// Valid values for the `status` element returned by `useModel`.
/** Initial status, before the model itself starts loading. */
export const INITIALIZING = 'INITIALIZING'
/** The model is being loaded (`loadGraphModel` / `loadLayersModel` in progress). */
export const LOADING = 'LOADING'
/** The model finished loading and is available in `modelData`. */
export const LOADED = 'LOADED'
/** Something went wrong (for example, no model URL was provided). */
export const ERROR = 'ERROR'

/** Arguments accepted by `useModel`. */
type IUseModelProps = {
  /** Extra TF.js load options; the hook always replaces `fetchFunc`. */
  options?: LoadOptions
  /** Firebase Storage path of the model JSON (resolved with `ref(storage, modelUrl)`). */
  modelUrl: string
}

/** Data exposed by `useModel` once a model is available. */
type IModelData = {
  /** URL the model was loaded from (the value passed to the hook). */
  modelUrl: string
  /** The loaded TF.js model instance. */
  model: GraphModel | LayersModel | null
  /** Resolves to the TF.js class of `model`. */
  getModelType: () => Promise<'LayersModel' | 'GraphModel'>
}

/** Return value of `useModel`: a `[status, modelData]` pair. */
type IUseModelType = [string, IModelData | null]

/**
 * React hook that loads a TensorFlow.js model stored in Firebase Storage.
 *
 * When `modelUrl` is set, the hook switches TF.js to the WASM backend, reads the
 * model JSON to decide between a graph model and a layers model, and loads it
 * with every file resolved through a Firebase Storage download URL. The model is
 * disposed when the component unmounts or `modelUrl` changes. A load that is
 * still in flight at that point is abandoned, and its result is disposed.
 *
 * @param props.modelUrl - Firebase Storage path of the model JSON. Required; an
 *   empty value puts the hook in the `ERROR` status.
 * @param props.options - Extra TF.js load options. `fetchFunc` is always
 *   overridden. Changing `options` alone does not trigger a reload.
 * @returns `[status, modelData]`, where `status` is `INITIALIZING`, `LOADING`,
 *   `LOADED` or `ERROR`, and `modelData` is `null` until the model has loaded.
 */
export const useModel = ({ modelUrl, options }: IUseModelProps): IUseModelType => {
  const [status, setStatus] = useState(INITIALIZING)
  const [modelData, setModelData] = useState<IModelData | null>(null)
  const storage = useStorage()

  // Resolves a Storage path to a download URL and fetches it. Used for the model
  // JSON and, through `fetchFunc`, for every file TF.js requests while loading.
  const fetchFromStorage = async (path: string, init?: RequestInit): Promise<Response> => {
    const urlDownload = await getDownloadURL(ref(storage, path))
    return fetch(urlDownload, init)
  }

  // Reads the model JSON and infers the model class from its `format` field.
  const fetchModelType = async (): Promise<'LayersModel' | 'GraphModel'> => {
    const res = await fetchFromStorage(modelUrl)
    const modelInfo = (await res.json()) as { format?: unknown }
    return modelInfo.format === 'graph-model' ? 'GraphModel' : 'LayersModel'
  }

  useEffect(() => {
    // Reset in case `modelUrl` changed after an earlier load. On first mount the
    // state already holds these values, so React bails out of the update.
    setModelData(null)
    setStatus(INITIALIZING)

    // Validated here rather than during render: setState during render loops.
    if (!modelUrl) {
      console.error('URL do modelo é obrigatória.')
      setStatus(ERROR)
      return undefined
    }

    // `cancelled` is flipped by the cleanup so a late-finishing load neither
    // updates state after unmount nor leaks its model. `loaded` is what the
    // cleanup disposes (reading `modelData` there would see a stale `null`).
    let cancelled = false
    let loaded: LayersModel | GraphModel | null = null

    const load = async (): Promise<void> => {
      tfWasm.setWasmPaths('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/wasm-out/')
      await tf.setBackend('wasm')

      const modelFormat = await fetchModelType()
      if (cancelled) return
      setStatus(LOADING)

      const loadOptions = { ...options }
      loadOptions.fetchFunc = (url: string) => fetchFromStorage(url, { cache: 'force-cache' })

      const model =
        modelFormat === 'GraphModel'
          ? await loadGraphModel(modelUrl, loadOptions)
          : await loadLayersModel(modelUrl, loadOptions)

      if (cancelled) {
        // The component went away (or the URL changed) mid-load: free it now.
        model.dispose()
        return
      }

      loaded = model
      console.log(`Modelo ${modelUrl} carregado.`)
      setModelData({
        model,
        // The format is already known, so answer from it instead of
        // re-downloading the model JSON.
        getModelType: () => Promise.resolve(modelFormat),
        modelUrl,
      })
      setStatus(LOADED)
    }

    load().catch((err: unknown) => {
      console.error(err)
      if (!cancelled) setStatus(ERROR)
    })

    // Dispose the model when the component unmounts or `modelUrl` changes.
    return () => {
      cancelled = true
      loaded?.dispose()
      loaded = null
    }
    // Reloading is keyed on the URL only; `options` is usually a fresh object
    // each render and `storage` comes from the reactfire context.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [modelUrl])

  return [status, modelData]
}
