// import * as tf from "@tensorflow/tfjs";
import { GraphModel, LayersModel, Rank, Tensor, tensor } from '@tensorflow/tfjs'
import { useModel, LOADED } from './useModel'

/**
 * Converts an arg-max prediction into a list of contiguous segments.
 *
 * Only the first row of `argMaxPredict` is inspected. Every maximal run of
 * consecutive entries equal to `1` becomes one `[start, end]` pair, where both
 * indices are inclusive. Each index is multiplied by `scale` and rounded up with
 * `Math.ceil`. A run that reaches the end of the row is closed at the row's last
 * index, including a run that starts on the last element.
 *
 * @param argMaxPredict Arg-max output. `null`, an empty array, or a missing first row yield `[]`.
 * @param scale Factor applied to each index before rounding up (default `1`).
 * @returns The segments, in ascending order of start index.
 */
const parse_segmentation_prediction = (argMaxPredict: number[][] | null, scale = 1): number[][] => {
  const row = argMaxPredict?.[0]
  if (!row) return []

  const toScaled = (index: number): number => Math.ceil(index * scale)
  const result: number[][] = []
  // Index where the current run of 1s began; null while not inside a run.
  let start: number | null = null

  for (const [index, value] of row.entries()) {
    if (value === 1) {
      if (start === null) start = index
    } else if (start !== null) {
      // The run ended on the previous element.
      result.push([toScaled(start), toScaled(index - 1)])
      start = null
    }
  }

  // If the row ends with 1s, close the open run at the last index.
  if (start !== null) {
    result.push([toScaled(start), toScaled(row.length - 1)])
  }

  return result
}

/** Signature of the prediction function handed out by `useSegmentationModel`. */
type SegmentationPredictFn = (kpts: number[][][], axis?: number) => Promise<number[][] | undefined>

/** Tuple returned by `useSegmentationModel`: `[status, predict]`. */
type UseSegmentationModelReturn = [string, SegmentationPredictFn]

/**
 * React hook that loads and runs a workspace's auto-segmentation model.
 *
 * The model is requested from `useModel` with the URL
 * `workspaces/<workspaceId>/models/segmentation/model.json`.
 *
 * @param workspaceId Identifier of the workspace whose segmentation model is loaded.
 * @returns A tuple `[status, predict]`.
 *   - `status` is the loading status reported by `useModel`. Prediction is only allowed
 *     when it equals `LOADED`.
 *   - `predict` runs the model on one keypoint sequence and resolves to the detected segments.
 */
const useSegmentationModel = (workspaceId: string): UseSegmentationModelReturn => {
  const modelUrl = `workspaces/${workspaceId}/models/segmentation/model.json`

  const [status, modelData] = useModel({
    modelUrl,
  })

  /**
   * Runs the segmentation model on one keypoint sequence.
   *
   * @param kpts Keypoint sequence. It is converted to a tensor and given a leading
   *   batch dimension of size 1 before inference.
   *   NOTE(review): presumably built from MediaPipe results (frames x landmarks x coords);
   *   confirm against the caller.
   * @param axis Axis along which `argMax` is taken on the model output (default `-1`).
   * @returns The segments that `parse_segmentation_prediction` extracts from the arg-max output.
   * @throws Rejects if the model is not loaded, if the model returns several outputs,
   *   or if querying the model type or running inference fails.
   */
  const predict = async (kpts: number[][][], axis?: number): Promise<number[][]> => {
    if (status !== LOADED || !modelData || !modelData.model) {
      console.warn('O modelo não está carregado.', status)
      throw new Error('O modelo não está carregado.')
    }

    // Track every tensor created here so it is released in `finally`.
    // tfjs tensor memory must be disposed explicitly.
    const tensors: Tensor<Rank>[] = []
    try {
      const rawInput = tensor(kpts)
      tensors.push(rawInput)
      // Add a batch dimension of size 1.
      const input = rawInput.expandDims(0)
      tensors.push(input)

      const modelFormat = await modelData.getModelType()
      const output =
        modelFormat === 'GraphModel'
          ? await (modelData.model as GraphModel).executeAsync(input)
          : (modelData.model as LayersModel).predict(input)

      if (Array.isArray(output)) {
        tensors.push(...output)
        throw new Error('O modelo retornou múltiplas saídas; esperava-se um único tensor.')
      }
      tensors.push(output)

      const argMaxTensor = output.argMax(axis ?? -1)
      tensors.push(argMaxTensor)
      const predictionsArray = (await argMaxTensor.array()) as number[][]

      return parse_segmentation_prediction(predictionsArray)
    } finally {
      // Release memory. Disposing an already-disposed tensor is a no-op in tfjs.
      tensors.forEach((t) => t.dispose())
    }
  }

  return [status, predict]
}

export default useSegmentationModel
