#ifndef __SAMPLES_MT_MODEL_HPP #define __SAMPLES_MT_MODEL_HPP #include #include #include #include #include #include #include #include "cviruntime.h" #include "cnpy.h" #define SOS_IDX 1 #define LEXICON_SIZE 16002 #define PAD_IDX 0 #define SOS_IDX 1 #define EOS_IDX 2 #define INFER_FIX_LEN 40 typedef uint16_t bf16_t; class Encoder { public: Encoder(const char *model_file); ~Encoder() { if (model) { CVI_NN_CleanupModel(model); } } bf16_t* run(int16_t *seq, int32_t size); bf16_t* get_mask(); public: CVI_MODEL_HANDLE model = nullptr; CVI_TENSOR *src_seq; CVI_TENSOR *src_mask; CVI_TENSOR *enc_output; private: void gen_src_mask(int16_t *src_seq, int32_t size); CVI_TENSOR *input_tensors; CVI_TENSOR *output_tensors; int32_t input_num; int32_t output_num; }; class Decoder { public: Decoder(CVI_MODEL_HANDLE model, int32_t max_step); ~Decoder() { if (model) { CVI_NN_CleanupModel(model); } } int16_t run(int step, int16_t *seq, bf16_t *enc, bf16_t *mask); public: CVI_TENSOR *trg_seq; CVI_TENSOR *trg_mask; CVI_TENSOR *enc_output; CVI_TENSOR *src_mask; CVI_TENSOR *dec_output; int32_t max_step; int32_t width; private: void gen_trg_mask(); int16_t argmax(int32_t step); CVI_MODEL_HANDLE model = nullptr; CVI_TENSOR *input_tensors; CVI_TENSOR *output_tensors; int32_t input_num; int32_t output_num; }; class MTrans { public: MTrans(const char *cvimodel) { encoder = new Encoder(cvimodel); decoder_0 = new Decoder(encoder->model, 0); decoder_10 = new Decoder(encoder->model, 10); decoder_20 = new Decoder(encoder->model, 20); decoder_30 = new Decoder(encoder->model, 30); decoder_39 = new Decoder(encoder->model, 39); } ~MTrans() { delete encoder; delete decoder_0; delete decoder_10; delete decoder_20; delete decoder_30; delete decoder_39; } void run(int16_t *seq, int32_t seq_sz, int16_t *gen_seq, int32_t gen_seq_sz); private: Encoder *encoder; Decoder *decoder_0; Decoder *decoder_10; Decoder *decoder_20; Decoder *decoder_30; Decoder *decoder_39; }; #endif