import { useRef } from 'react'
import * as tf from '@tensorflow/tfjs'
import { Shape, Tensor } from '@tensorflow/tfjs'
import { getDownloadURL, ref } from 'firebase/storage'
import { ref as refDatabase } from 'firebase/database'
import { useDatabase, useStorage } from 'reactfire'
import { getData } from 'collections'

/** Input for the `predict` function returned by `useModel`. */
interface IPredictProps {
  /** Tensor passed directly to `model.predict`; `useModel` never disposes it (caller-owned). */
  dataTensor: tf.Tensor<tf.Rank>
  /** When class labels are loaded, passed as `k` to `topk` to choose how many labels to decode. */
  range?: number
  /** When no class labels are loaded, passed as the axis to `tf.argMax`. */
  axis?: number
}

/**
 * Input for the `predictSiamese` function returned by `useModel`: every field of
 * `IPredictProps` plus the second tensor fed to the siamese model.
 */
interface IPredictPropsSiamese extends IPredictProps {
  /** Second model input; passed to the model ahead of `dataTensor`. */
  segTensor: tf.Tensor<tf.Rank>
}

// Custom layer of the siamese network: Euclidean (L2) distance between its two inputs.
class L2_Layer extends tf.layers.Layer {
  /**
   * Serialization class name. Reported as 'Lambda' — presumably so that the
   * Lambda layer stored in the exported model.json resolves to this class; confirm.
   */
  static get className() {
    return 'Lambda'
  }

  constructor() {
    super({})
    this.supportsMasking = true
    this.name = 'L2_Layer'
  }

  /**
   * One distance per row: keeps the first input's leading (batch) dimension,
   * or null when it is unknown, followed by a trailing dimension of 1.
   */
  computeOutputShape(inputShape: Shape[]) {
    const batchSize = inputShape[0][0] || null
    return [batchSize, 1]
  }

  /**
   * Euclidean distance between the two inputs, reduced over axis 1 with the
   * reduced dimension kept. The squared sum is floored at 1e-7 before the square
   * root, presumably to keep sqrt (and its gradient) finite for identical inputs.
   */
  call(inputs: Tensor[]) {
    const [left, right] = inputs
    const difference = tf.sub(left, right)
    const squaredDistance = tf.sum(tf.square(difference), 1, true)
    const floored = tf.maximum(squaredDistance, 1e-7)
    return tf.sqrt(floored)
  }
}

// Register the class with tfjs so model loading can deserialize it by className.
tf.serialization.registerClass(L2_Layer)

/** Object returned by the `useModel` hook. */
export interface IUseModelFns {
  /**
   * The loaded model, or null. In `useModel` this is read from a ref when the hook
   * renders, so it reflects the state at that render, not a live value.
   */
  model: tf.LayersModel | null
  /** Class-index → label map loaded from `classesPath`, or null when none is loaded. */
  classes: Record<number, string> | null

  /** Loads the model (and class labels, if configured); returns early if a model is already loaded. */
  download: () => Promise<void>

  /**
   * Runs the model on `data.dataTensor`; resolves to null when no model is loaded.
   * The `useModel` implementation resolves to decoded labels (`string[]`) when class
   * labels are loaded, otherwise to the nested-array result of `tf.argMax`.
   */
  predict: (
    data: IPredictProps,
  ) => Promise<
    | tf.Tensor<tf.Rank>
    | string[]
    | number
    | number[]
    | number[][]
    | number[][][]
    | number[][][][]
    | number[][][][][]
    | number[][][][][][]
    | null
  >
  /**
   * Runs the siamese model on `[data.segTensor, data.dataTensor]`; resolves to null
   * when no model is loaded. The `useModel` implementation resolves to the raw
   * typed-array output of the model.
   */
  predictSiamese: (
    data: IPredictPropsSiamese,
  ) => Promise<
    | tf.Tensor<tf.Rank>
    | string[]
    | number
    | number[]
    | number[][]
    | number[][][]
    | number[][][][]
    | number[][][][][]
    | number[][][][][][]
    | Float32Array
    | Int32Array
    | Uint8Array
    | null
  >
}

/** Options for the `useModel` hook. */
interface IUseModelProps {
  /**
   * Firebase Storage path of the model.json. With `motionModel`, the base path
   * under which `v<version>/model.json` is looked up.
   */
  modelPath: string
  /**
   * Optional Firebase Storage path of the class-label JSON. With `motionModel`,
   * the base path under which `v<version>/classes.json` is looked up.
   */
  classesPath?: string
  /**
   * With `motionModel`: names the Realtime Database version entry
   * (`modelsVersion/<modelName>`), the IndexedDB cache key and the localStorage
   * version key.
   */
  modelName?: string
  /** When true, load a versioned model and cache it in IndexedDB. */
  motionModel?: boolean
}

/**
 * Loads a TensorFlow.js layers model (and, optionally, a class-index → label map)
 * from Firebase Storage and exposes prediction helpers.
 *
 * - Plain models (`motionModel` falsy) are fetched from `modelPath`, labels from
 *   `classesPath`.
 * - Motion models read their current version from the Realtime Database at
 *   `modelsVersion/<modelName>` and load `<modelPath>/v<version>/model.json`.
 *   A copy is kept in IndexedDB (`indexeddb://<modelName>`) and reused only when
 *   the version recorded in localStorage (`<modelName>ModelOldVersion`) matches
 *   the database version.
 *
 * NOTE(review): `model` and `classes` in the returned object are read from refs at
 * render time. Refs do not trigger re-renders, so these fields stay `null` until
 * the component re-renders after `download()` resolves.
 */
const useModel = ({ modelPath, classesPath, modelName, motionModel }: IUseModelProps): IUseModelFns => {
  const model = useRef<tf.LayersModel | null>(null)
  const classes = useRef<Record<number, string> | null>(null)
  const storage = useStorage()
  const db = useDatabase()

  /** Loads a layers model whose files (model.json and weights) live in Firebase Storage. */
  const loadFromStorage = (path: string) =>
    tf.loadLayersModel(path, {
      // Every URL tfjs requests is decoded, treated as a Storage path and
      // resolved to a download URL before fetching.
      fetchFunc: async (url: string) => {
        const urlDownload = await getDownloadURL(ref(storage, decodeURIComponent(url)))
        return fetch(urlDownload, { cache: 'force-cache' })
      },
    })

  /** Downloads a class-label JSON file from Storage into `classes.current`. */
  const loadClasses = async (path: string) => {
    const url = await getDownloadURL(ref(storage, path))
    const response = await fetch(url, { cache: 'force-cache' })
    classes.current = await response.json()
  }

  /** Reads the model cached in IndexedDB; resolves to null if it cannot be loaded. */
  const loadCachedModel = async (cacheKey: string): Promise<tf.LayersModel | null> => {
    try {
      return await tf.loadLayersModel(cacheKey)
    } catch {
      return null
    }
  }

  /** Best-effort write to IndexedDB; resolves to whether the save succeeded. */
  const saveToCache = async (target: tf.LayersModel, cacheKey: string) => {
    try {
      await target.save(cacheKey)
      return true
    } catch (error: unknown) {
      console.warn(`Could not cache model at ${cacheKey}`, error)
      return false
    }
  }

  /** Versioned download with IndexedDB caching (motion models). */
  const downloadMotionModel = async () => {
    const snapshot = await getData(refDatabase(db, `modelsVersion/${modelName}`))
    const version = snapshot.val()
    if (version == null) {
      throw new Error(`No model version found at modelsVersion/${modelName}`)
    }

    const versionKey = `${modelName}ModelOldVersion`
    const cacheKey = `indexeddb://${modelName}`
    // localStorage stores strings, so compare against the string form of the version.
    const isCacheCurrent = localStorage.getItem(versionKey) === String(version)

    // Only touch IndexedDB when the cached copy is known to be this version; a
    // missing or unknown recorded version means the cache cannot be trusted.
    const cached = isCacheCurrent ? await loadCachedModel(cacheKey) : null
    if (cached) {
      model.current = cached
    } else {
      const fresh = await loadFromStorage(`${modelPath}/v${version}/model.json`)
      model.current = fresh
      // Record the version only once the cache really holds it, so a failed save
      // can never mark a stale cached model as current.
      if (await saveToCache(fresh, cacheKey)) {
        localStorage.setItem(versionKey, String(version))
      }
    }

    if (classesPath) {
      await loadClasses(`${classesPath}/v${version}/classes.json`)
    }
  }

  /** Loads the model (and class labels, if configured); returns early if already loaded. */
  const download = async () => {
    if (model.current) return
    if (motionModel) {
      await downloadMotionModel()
      return
    }
    model.current = await loadFromStorage(modelPath)
    if (classesPath) {
      await loadClasses(classesPath)
    }
  }

  /**
   * Runs the model on `dataTensor`. With class labels loaded, resolves to the labels
   * of the (flattened) indices returned by `topk(range)`; otherwise to
   * `argMax(output, axis)` as a nested array. Resolves to null when no model is
   * loaded. Every tensor created here is disposed; `dataTensor` stays caller-owned.
   */
  const predict = async (data: IPredictProps) => {
    const current = model.current
    if (!current) return null
    const { dataTensor, range, axis } = data
    const decodeToken = classes.current
    const output = current.predict(dataTensor) as tf.Tensor<tf.Rank>

    try {
      if (decodeToken) {
        const { values, indices } = output.topk(range)
        try {
          const indicesData = await indices.data()
          return [...indicesData].map((index) => decodeToken[index])
        } finally {
          values.dispose()
          indices.dispose()
        }
      }

      const argMax = tf.argMax(output, axis)
      try {
        return await argMax.array()
      } finally {
        argMax.dispose()
      }
    } finally {
      output.dispose()
    }
  }

  /**
   * Runs the siamese model on `[segTensor, dataTensor]` (in that order) and resolves
   * to the output's typed-array data, or null when no model is loaded. The output
   * tensor is disposed after its data has been read.
   */
  const predictSiamese = async (data: IPredictPropsSiamese) => {
    const current = model.current
    if (!current) return null
    const { dataTensor, segTensor } = data
    const output = current.predict([segTensor, dataTensor]) as tf.Tensor<tf.Rank>
    try {
      return await output.data()
    } finally {
      output.dispose()
    }
  }

  return {
    model: model.current,
    download,
    classes: classes.current ?? null,
    predict,
    predictSiamese,
  }
}

export default useModel
