Files
SDK_SG200x_V2/cviruntime/samples_inner/mt/mt_model.hpp
carbon e25f20f7a3 add cviruntime
commit 3f4938648950a7f3bf9a19c320ca9fae7c52de20
Author: sophgo-forum-service <forum_service@sophgo.com>
Date:   Mon May 13 13:44:23 2024 +0800

    [feat] cviruntime opensource for cv18xx soc.

    - a4b6a3, add cumsum and gatherelements_pt.
2024-05-31 11:51:34 +08:00

114 lines
2.2 KiB
C++

#ifndef __SAMPLES_MT_MODEL_HPP
#define __SAMPLES_MT_MODEL_HPP
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <opencv2/opencv.hpp>
#include "cviruntime.h"
#include "cnpy.h"
// Token ids and sizes shared by the MT sample.
// NOTE(review): names suggest padding / start-of-sequence / end-of-sequence
// token ids and the vocabulary size — confirm against the model's lexicon.
// (SOS_IDX was previously #defined twice; it is now defined exactly once.)
#define PAD_IDX 0
#define SOS_IDX 1
#define EOS_IDX 2
#define LEXICON_SIZE 16002
// Fixed sequence length used at inference time (presumably the maximum number
// of decoding steps — TODO confirm in mt_model.cpp).
#define INFER_FIX_LEN 40

// bf16 values are carried as raw 16-bit patterns; no arithmetic is provided
// by this alias.
using bf16_t = uint16_t;
// Encoder stage of the MT sample.
//
// Ownership: ~Encoder() passes `model` to CVI_NN_CleanupModel(), so an Encoder
// owns its model handle. The class is therefore non-copyable; the implicitly
// generated copy operations would duplicate the handle and release it twice.
//
// The constructor, run(), get_mask() and gen_src_mask() are defined out of
// line and are not visible here, so the notes on them are hedged.
class Encoder {
public:
  // Builds the encoder from the cvimodel file at `model_file`.
  // NOTE(review): presumably registers the model and fills in the tensor
  // members below — confirm against the out-of-line definition.
  Encoder(const char *model_file);

  ~Encoder() {
    // A null handle means there is nothing to release.
    if (model) {
      CVI_NN_CleanupModel(model);
    }
  }

  // Non-copyable: `model` is released in the destructor.
  Encoder(const Encoder &) = delete;
  Encoder &operator=(const Encoder &) = delete;

  // Runs the encoder on `size` int16_t token ids starting at `seq`.
  // NOTE(review): the returned bf16 buffer presumably belongs to a model
  // tensor (enc_output?) — confirm its lifetime before caching the pointer.
  bf16_t* run(int16_t *seq, int32_t size);
  // Returns the source mask as bf16 values (presumably the mask built by
  // gen_src_mask() during the last run() — TODO confirm).
  bf16_t* get_mask();

public:
  // Model handle; stays nullptr until the constructor assigns it.
  CVI_MODEL_HANDLE model = nullptr;
  // Named I/O tensors; default to nullptr so they are never read uninitialized.
  CVI_TENSOR *src_seq = nullptr;
  CVI_TENSOR *src_mask = nullptr;
  CVI_TENSOR *enc_output = nullptr;

private:
  void gen_src_mask(int16_t *src_seq, int32_t size);
  // Tensor arrays and their element counts.
  CVI_TENSOR *input_tensors = nullptr;
  CVI_TENSOR *output_tensors = nullptr;
  int32_t input_num = 0;
  int32_t output_num = 0;
};
// Decoder stage of the MT sample, constructed from an existing model handle
// and a `max_step` value.
//
// Ownership: ~Decoder() passes its private `model` to CVI_NN_CleanupModel(),
// so a Decoder owns that handle and is non-copyable. Copies would release the
// same handle twice.
class Decoder {
public:
  // NOTE(review): the `model` argument is a handle owned elsewhere (MTrans
  // passes encoder->model). Because ~Decoder() cleans up this->model, the
  // out-of-line constructor must give this Decoder its own handle (e.g. a
  // clone). If it stored the caller's handle directly, that handle would be
  // released twice. Confirm against the definition.
  Decoder(CVI_MODEL_HANDLE model, int32_t max_step);

  ~Decoder() {
    // A null handle means there is nothing to release.
    if (model) {
      CVI_NN_CleanupModel(model);
    }
  }

  // Non-copyable: `model` is released in the destructor.
  Decoder(const Decoder &) = delete;
  Decoder &operator=(const Decoder &) = delete;

  // Performs decoding at position `step` given the target sequence `seq`,
  // encoder output `enc` and source mask `mask`; returns an int16_t token id.
  // NOTE(review): presumably the result of argmax(step) over dec_output —
  // confirm in the out-of-line definition.
  int16_t run(int step, int16_t *seq,
              bf16_t *enc, bf16_t *mask);

public:
  // Named I/O tensors; default to nullptr so they are never read uninitialized.
  CVI_TENSOR *trg_seq = nullptr;
  CVI_TENSOR *trg_mask = nullptr;
  CVI_TENSOR *enc_output = nullptr;
  CVI_TENSOR *src_mask = nullptr;
  CVI_TENSOR *dec_output = nullptr;
  // Set by the constructor; zero-initialized here so they are never garbage.
  int32_t max_step = 0;
  int32_t width = 0;

private:
  void gen_trg_mask();
  int16_t argmax(int32_t step);
  CVI_MODEL_HANDLE model = nullptr;
  // Tensor arrays and their element counts.
  CVI_TENSOR *input_tensors = nullptr;
  CVI_TENSOR *output_tensors = nullptr;
  int32_t input_num = 0;
  int32_t output_num = 0;
};
// Top-level translator for the MT sample: one Encoder plus five Decoders.
// Each Decoder is built from the encoder's model handle with a different
// max_step (0, 10, 20, 30, 39).
//
// MTrans owns all six objects through raw pointers that it deletes in its
// destructor. It is therefore non-copyable; a copy would delete every object
// twice.
class MTrans {
public:
  MTrans(const char *cvimodel) {
    // The decoders are derived from the encoder's handle, so the encoder must
    // be created first.
    encoder = new Encoder(cvimodel);
    decoder_0 = new Decoder(encoder->model, 0);
    decoder_10 = new Decoder(encoder->model, 10);
    decoder_20 = new Decoder(encoder->model, 20);
    decoder_30 = new Decoder(encoder->model, 30);
    decoder_39 = new Decoder(encoder->model, 39);
  }

  ~MTrans() {
    // Tear down in reverse order of construction. The decoders were created
    // from encoder->model, so they are released before the encoder releases
    // the handle they were derived from.
    delete decoder_39;
    delete decoder_30;
    delete decoder_20;
    delete decoder_10;
    delete decoder_0;
    delete encoder;
  }

  // Non-copyable: the members are owning raw pointers deleted in ~MTrans().
  MTrans(const MTrans &) = delete;
  MTrans &operator=(const MTrans &) = delete;

  // Translates `seq_sz` token ids from `seq`; output goes to `gen_seq`
  // (`gen_seq_sz` elements).
  // NOTE(review): defined out of line — confirm how gen_seq is filled/padded.
  void run(int16_t *seq, int32_t seq_sz,
           int16_t *gen_seq, int32_t gen_seq_sz);

private:
  Encoder *encoder = nullptr;
  Decoder *decoder_0 = nullptr;
  Decoder *decoder_10 = nullptr;
  Decoder *decoder_20 = nullptr;
  Decoder *decoder_30 = nullptr;
  Decoder *decoder_39 = nullptr;
};
#endif