import * as tf from '@tensorflow/tfjs';

/**
 * Wraps a trained tfjs model together with the lookup tables needed to turn a
 * raw sentence into model input and the model's output back into a label.
 */
export class PredictionModel {
  /** Trained model; `predict` is called with a `[1, sequenceLength]` tensor. */
  model: tf.LayersModel;
  /** Output index -> label; the predicted label is `labelEncoder[argmax(prediction)]`. */
  labelEncoder: string[];
  /** Word -> input index, where a word's index is its position in this array. */
  vocabularyEncoder: string[];
  /**
   * Number of word slots per encoded sentence. Shorter sentences are padded
   * with 0; longer ones are truncated to their first `sequenceLength` words.
   */
  readonly sequenceLength: number;

  /**
   * @param model trained model used by {@link predictCategory}.
   * @param labelEncoder labels indexed by model output position.
   * @param vocabularyEncoder vocabulary; a word's input index is its position here.
   * @param sequenceLength fixed number of word slots per sentence (defaults to 10,
   *   the width previously hard-coded). NOTE(review): must match the model's input width.
   * @throws RangeError if `sequenceLength` is not a positive integer.
   */
  constructor(
    model: tf.LayersModel,
    labelEncoder: string[],
    vocabularyEncoder: string[],
    sequenceLength = 10,
  ) {
    if (!Number.isInteger(sequenceLength) || sequenceLength < 1) {
      throw new RangeError(`sequenceLength must be a positive integer, got ${sequenceLength}`);
    }
    this.model = model;
    this.labelEncoder = labelEncoder;
    this.vocabularyEncoder = vocabularyEncoder;
    this.sequenceLength = sequenceLength;
  }

  /**
   * Classifies a sentence and returns the label whose model output is largest.
   *
   * @param input raw sentence; encoded via {@link getTensorFromSentence}.
   * @returns the entry of `labelEncoder` at the argmax of the model's output.
   */
  predictCategory(input: string) {
    // tf.tidy disposes every tensor created inside the callback (input,
    // prediction, flattened view, argMax result) once it returns; without it
    // each call leaked all of them.
    const resultIdx = tf.tidy(() => {
      const prediction = this.model.predict(this.getTensorFromSentence(input)) as tf.Tensor;
      // Flatten first so the argmax runs over every output value, matching an
      // argmax over the flat `dataSync()` array.
      return prediction.flatten().argMax().dataSync()[0];
    });
    return this.labelEncoder[resultIdx];
  }

  /**
   * Encodes a sentence as a `[1, sequenceLength]` tensor of vocabulary indexes.
   *
   * Hyphens become spaces, the listed punctuation is stripped, text is
   * lower-cased and split on single spaces. Empty tokens are dropped. At most
   * `sequenceLength` words are kept, and unused slots stay 0.
   *
   * @param input raw sentence; falsy input yields an all-zero row.
   * @returns a 2-D tensor with a single row of word indexes.
   */
  getTensorFromSentence(input: string) {
    const indexes = new Array<number>(this.sequenceLength).fill(0);
    if (input) {
      const words = input
        .replace(/[-]/g, ' ')
        .replace(/[#$%&'()*+,-./:;<=>?@[\]^_`{|}~!]/g, '')
        .toLowerCase()
        .split(' ')
        // Runs of spaces (e.g. "a - b", or spaces around stripped punctuation)
        // split into '' tokens; they are not words and must not take slots.
        .filter((word) => word.length > 0);
      // Only the first `sequenceLength` words fit. Writing past the end would
      // widen the row beyond the fixed width the padding implies.
      words.slice(0, this.sequenceLength).forEach((word, i) => {
        // NOTE(review): words absent from the vocabulary encode as -1, as
        // before; confirm the model tolerates that index.
        indexes[i] = this.vocabularyEncoder.indexOf(word);
      });
    }
    return tf.tensor2d([indexes]);
  }
}
