Files
SDK_SG200x_V2/cvikernel/src/bm1822/tiu_matrix_multiplication.c
carbon 88a2fed916 add cvikernel
commit 9f1f57a19c3c281a931dfc71b318494487193d56
Author: sophgo-forum-service <forum_service@sophgo.com>
Date:   Mon May 13 13:58:23 2024 +0800

    [feat] cvikernel opensource for cv18xx soc.

    - 79b6a7, set lookup_interp_table layer_id.
2024-05-31 11:46:37 +08:00

152 lines
3.9 KiB
C

#include "kernel_1822.h"
typedef bmk1822_tiu_matrix_multiplication_param_t param_t;
static void check_matrix(ctx_t *ctx, const ml_t *m)
{
bmk1822_tensor_lmem_t t;
t.start_address = m->start_address;
t.fmt = m->fmt;
t.shape.n = m->shape.n;
t.shape.c = m->shape.c;
t.shape.h = 1;
t.shape.w = m->shape.w;
t.stride.n = m->stride.n;
t.stride.c = m->stride.c;
t.stride.h = m->stride.h;
t.stride.w = 1 * (m->fmt == FMT_BF16 ? 2 : 1);
check_tiu_tensor(&t);
assert_stride_type_0(ctx, &t);
uint32_t eu_num = ctx->chip_info.eu_num;
ASSERT(m->start_address % eu_num == 0);
}
static int is_arith_shift(const param_t *p)
{
if (p->left->fmt == FMT_I8)
return 1;
if (p->right->fmt == FMT_I8)
return 1;
if (p->bias && p->bias->fmt == FMT_I8)
return 1;
return 0;
}
bmk1822_op_t * bmk1822_tiu_matrix_multiplication(ctx_t *ctx, const param_t *p)
{
const bmk1822_matrix_lmem_t *res = p->res;
const bmk1822_matrix_lmem_t *left = p->left;
const bmk1822_matrix_lmem_t *right = p->right;
const bmk1822_matrix_lmem_t *bias = p->bias;
int bf16_enable = (res->fmt == FMT_BF16) ? 1 : 0;
check_matrix(ctx, res);
check_matrix(ctx, left);
check_matrix(ctx, right);
if (bias)
check_matrix(ctx, bias);
ASSERT(p->lshift_bits < 32);
if (bf16_enable) /* bf16 does not support add_result*/
ASSERT(!p->add_result);
else
ASSERT(!(p->relu_enable && p->add_result));
if(p->ps32_mode & 0x2)
{
ASSERT(!p->relu_enable);
ASSERT(!p->bias);
ASSERT(!p->rshift_bits);
}
ASSERT(p->relu_enable == 0 || p->relu_enable == 1);
uint32_t left_row = left->shape.n;
uint32_t left_col = left->shape.col;
uint32_t right_row = right->shape.n;
uint32_t right_col = right->shape.col;
uint32_t res_row = res->shape.n;
uint32_t res_col = res->shape.col;
ASSERT(left_col == right_row);
ASSERT(res_col == right_col);
if(p->ps32_mode)
{
ASSERT(!p->add_result);
} else if ((p->add_result || !p->res_is_int8) && !bf16_enable) {
ASSERT(res_row == left_row * 2);
res_row = left_row;
} else {
ASSERT(res_row == left_row);
}
tiu_reg_t reg;
reset_tiu_reg(&reg);
reg.cmd_en = 1;
reg.tsk_typ = DCR_TYPE_FC_FIX8B;
reg.tsk_opd_num = bias? 3: 2;
reg.opd_typ = bf16_enable ? 1 : 0;
reg.opt_shift_typ = is_arith_shift(p);
reg.opt_res_shift = p->rshift_bits;
reg.opt_left_shift = p->lshift_bits;
reg.opt_relu_typ = p->relu_enable;
reg.opt_res_add = p->add_result;
reg.res0_addr = res->start_address;
reg.opt_res0_seg = (bf16_enable ? 1 : p->res_is_int8);
reg.opt_res0_sign = matrix_is_signed(res);
reg.res0_n = res_row;
reg.res0_c = res->shape.c;
reg.res0_h = 1;
reg.res0_w = res->shape.w;
reg.short_res0_str = 0; // stride, b_stride calculated by H/W
reg.opd0_addr = left->start_address;
reg.opt_opd0_seg = 1;
reg.opt_opd0_sign = (left->fmt == FMT_I8);
reg.opd0_n = left_row;
reg.opd0_c = left->shape.c;
reg.opd0_h = 1;
reg.opd0_w = left->shape.w;
reg.short_opd0_str = 0;
reg.opd1_addr = right->start_address;
reg.opt_opd1_seg = 1;
reg.opt_opd1_sign = (right->fmt == FMT_I8);
reg.opd1_n = right_row;
reg.opd1_c = right->shape.c;
reg.opd1_h = 1;
reg.opd1_w = left_col - left->shape.w * (left->shape.c - 1);
reg.short_opd1_str = 0;
reg.ps32_md = p->ps32_mode;
if (p->ps32_mode > 0)
reg.res0_b_str = p->res->shape.n * p->res->stride.n;
if(reg.opd0_c == 1)
ASSERT(reg.opd0_w == reg.opd1_w);
if (bias) {
ASSERT(bias->shape.n == 2);
ASSERT(bias->shape.c == right->shape.c);
ASSERT(bias->shape.w == right->shape.w);
ASSERT(bias->shape.col == right->shape.col);
reg.opd2_addr = bias->start_address;
reg.opt_opd2_seg = 0;
reg.opt_opd2_sign = (bias->fmt == FMT_I8);
reg.opd2_n = 1;
reg.opd2_c = bias->shape.c;
reg.opd2_h = 1;
reg.opd2_w = bias->shape.w;
reg.short_opd2_str = 0;
}
reg.layer_info = p->layer_id;
return emit_tiu_cmdbuf(ctx, &reg);
}