asr_prefetch.c

#include "include.h"
#include "asr.h"
#if ASR_PREFETCH_EN && ASR_EN
/// The x0 and sum matrices may only be placed in ram0 (0x50000~0x54000) or ram1 (0x54000~0x59000);
/// the y0 matrix may use weight0 through ram1 (0x40000~0x59000). All of them must be 4-byte aligned,
/// and x0 and y0 must not sit in the same RAM block.
void matrix_hw(int32_t *sum, int8_t *x0, int8_t *y0, s16 loop);
void npu_memcpy_int16(int16_t *c, int16_t *a, int32_t a_size);
void npu_matrix_kick_wait(void);
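#if 0
/*
 * Minimal usage sketch of the NPU matrix call, not compiled in. It mirrors the
 * !NPU_CONTIN_CAL_EN path of tdnn_compute() below: when the weight row (x0) is not
 * 16-byte aligned it is first staged into the npu_matrix_x buffer declared further
 * down, and because the hardware runs asynchronously the accumulator must not be
 * read before npu_matrix_kick_wait() returns. The names sum/row/in_buf/in_dim are
 * illustrative; sum is assumed to live in one of the RAM regions listed above.
 */
static void matrix_hw_sketch(int32_t *sum, int8_t *row, int8_t *in_buf, s16 in_dim)
{
    *sum = 0;
    if (((int)row % 16) != 0) {
        memcpy(npu_matrix_x, row, in_dim);    /* stage the row in the aligned ram0 buffer */
        matrix_hw(sum, npu_matrix_x, in_buf, in_dim);
    } else {
        matrix_hw(sum, row, in_buf, in_dim);
    }
    npu_matrix_kick_wait();                   /* the result is valid only after the kick completes */
}
#endif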
#if NPU_CONTIN_CAL_EN
AT(.ws_asr.ram0.sum)
int32_t sum_buffer[MAX_NPU_MATRIX * 16];
#else
AT(.ws_asr.ram0.sum.s)
int sum_buffer[0x1e8];
AT(.ws_asr.ram0.sum)
int32_t matrix_sum;
#endif
AT(.ws_asr.offset)
int8_t asr_offset[12]; // padding so the arrays inside the ASR library stay 16-byte aligned
AT(.npu_matrix.ram0.x)
int8_t npu_matrix_x[256 * MAX_NPU_MATRIX];
typedef struct {
    u16 all_frame;
    u16 frame_len;
    u8  load_frame;
} tdnn_prenum_t;

AT(.com_rodata.tdnn)
const tdnn_prenum_t tdnn_prenum[7] = {
    {128, 120, 68},
    {128, 256, 32},
    {128, 256, 32},
    {128, 256, 32},
    {128, 256, 32},
    {128, 128, 64},
    {488, 128, 64},
};
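// Per-layer prefetch geometry: all_frame is the number of weight rows in the layer,
// frame_len is the byte length of one row (the in_dim consumed per output in
// tdnn_compute()), and load_frame is how many rows are pulled from flash per prefetch.
// Each load_frame * frame_len chunk fits one 0x2000-byte half of tdnn_buffer below
// (68 * 120 = 8160, 32 * 256 = 8192, 64 * 128 = 8192).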
typedef struct {
    u8 load_frame;
    u8 *buf;
} tdnn_cache_t;

typedef struct {
    u16 all_frame;
    u8  index;
    u8  toggle;
    u32 load_addr;
    tdnn_cache_t cache[2];
} tdnn_pretetch_t;
extern const int16_t tdnn_mean_1[128];
extern const int16_t tdnn_bias_1[128];
extern const float   tdnn_var_1[128];
extern const int16_t tdnn_mean_2[128];
extern const int16_t tdnn_bias_2[128];
extern const float   tdnn_var_2[128];
extern const int16_t tdnn_mean_3[128];
extern const int16_t tdnn_bias_3[128];
extern const float   tdnn_var_3[128];
extern const int16_t tdnn_mean_4[128];
extern const int16_t tdnn_bias_4[128];
extern const float   tdnn_var_4[128];
extern const int16_t tdnn_mean_5[128];
extern const int16_t tdnn_bias_5[128];
extern const float   tdnn_var_5[128];
extern const int16_t tdnn_mean_6[128];
extern const int16_t tdnn_bias_6[128];
extern const float   tdnn_var_6[128];
extern const int16_t tdnn_bias_7[488];
extern const float   tdnn_scale_7[488];
int16_t tdnn_ram_mean1[128]  AT(.ws_asr.test);
int16_t tdnn_ram_bias1[128]  AT(.ws_asr.test);
float   tdnn_ram_var1[128]   AT(.ws_asr.test);
int16_t tdnn_ram_mean2[128]  AT(.ws_asr.test);
int16_t tdnn_ram_bias2[128]  AT(.ws_asr.test);
float   tdnn_ram_var2[128]   AT(.ws_asr.test);
int16_t tdnn_ram_mean3[128]  AT(.ws_asr.test);
int16_t tdnn_ram_bias3[128]  AT(.ws_asr.test);
float   tdnn_ram_var3[128]   AT(.ws_asr.test);
int16_t tdnn_ram_mean4[128]  AT(.ws_asr.test);
int16_t tdnn_ram_bias4[128]  AT(.ws_asr.test);
float   tdnn_ram_var4[128]   AT(.ws_asr.test);
int16_t tdnn_ram_mean5[128]  AT(.ws_asr.test);
int16_t tdnn_ram_bias5[128]  AT(.ws_asr.test);
float   tdnn_ram_var5[128]   AT(.ws_asr.test);
int16_t tdnn_ram_mean6[128]  AT(.ws_asr.test);
int16_t tdnn_ram_bias6[128]  AT(.ws_asr.test);
float   tdnn_ram_var6[128]   AT(.ws_asr.test);
int16_t tdnn_ram_bias7[488]  AT(.ws_asr.test);
float   tdnn_ram_scale7[488] AT(.ws_asr.test);
u8 tdnn_buffer[2][0x2000] AT(.npu_matrix.ram0.tdnn); // double buffer for weight blocks streamed from flash
tdnn_pretetch_t tdnn_pretetch AT(.ws_asr.test);
u8 asr_prefetch_kisck; // nonzero makes tdnn_compute() return immediately (set to 5 in asr_prefetch_init())
AT(.com_text.tdnn)
void tdnn_compute_float_cal(float *out_buf, const int16_t *tdnn_mean, const float *tdnn_scale,
                            const float *tdnn_var, const int16_t *tdnn_bias, float in_scale,
                            u16 matrix_cnt_t, u16 float_cnt_t, int last_layer)
{
    int32_t sum_curr = 0, z = 0;
    for (int cnt = float_cnt_t; cnt < (matrix_cnt_t + float_cnt_t); cnt++) {
        z = sum_buffer[(sum_curr << 4)] * in_scale + tdnn_bias[cnt];
        if (!last_layer) {
            if (z <= 0) {
                z = 0;
            }
            out_buf[cnt] = (z - tdnn_mean[cnt]) * tdnn_var[cnt];
        } else {
            out_buf[cnt] = z * tdnn_scale[cnt];
        }
        sum_curr++;
    }
}
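// Written out, the conversion above is, for each output index cnt in
// [float_cnt_t, float_cnt_t + matrix_cnt_t) with k counting accumulators from 0:
//     z = sum_buffer[16 * k] * in_scale + bias[cnt]
//     hidden layers: out[cnt] = (max(z, 0) - mean[cnt]) * var[cnt]
//     last layer:    out[cnt] = z * scale[cnt]
// z is an int32_t, so the scaled accumulator is truncated to an integer before the
// mean/var (or scale) stage is applied in float.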
AT(.com_text.tdnn)
void tdnn_compute(int8_t *in_buf, float in_scale, int in_dim, int out_dim,
                  const int8_t *tdnn_weight, const int16_t *tdnn_bias,
                  const int16_t *tdnn_mean, const float *tdnn_var, const float *tdnn_scale,
                  float *out_buf, int last_layer)
{
    int i;
    if (asr_prefetch_kisck) {
        return;
    }
    tdnn_pretetch_t *t = &tdnn_pretetch;
    const tdnn_prenum_t *p = &tdnn_prenum[t->index];
    u32 frame_len = p->frame_len * p->load_frame;
    // printf("tdnn_weight:%x out_dim:%d in_dim:%d index:%d frame_len:%d\n", tdnn_weight, out_dim, in_dim, t->index, frame_len);
    // Redirect the per-layer tables from flash to the RAM copies filled in asr_prefetch_init_do().
    if (tdnn_mean == tdnn_mean_1) {
        tdnn_mean = tdnn_ram_mean1;
        tdnn_bias = tdnn_ram_bias1;
        tdnn_var  = tdnn_ram_var1;
    } else if (tdnn_mean == tdnn_mean_2) {
        tdnn_mean = tdnn_ram_mean2;
        tdnn_bias = tdnn_ram_bias2;
        tdnn_var  = tdnn_ram_var2;
    } else if (tdnn_mean == tdnn_mean_3) {
        tdnn_mean = tdnn_ram_mean3;
        tdnn_bias = tdnn_ram_bias3;
        tdnn_var  = tdnn_ram_var3;
    } else if (tdnn_mean == tdnn_mean_4) {
        tdnn_mean = tdnn_ram_mean4;
        tdnn_bias = tdnn_ram_bias4;
        tdnn_var  = tdnn_ram_var4;
    } else if (tdnn_mean == tdnn_mean_5) {
        tdnn_mean = tdnn_ram_mean5;
        tdnn_bias = tdnn_ram_bias5;
        tdnn_var  = tdnn_ram_var5;
    } else if (tdnn_mean == tdnn_mean_6) {
        tdnn_mean = tdnn_ram_mean6;
        tdnn_bias = tdnn_ram_bias6;
        tdnn_var  = tdnn_ram_var6;
    } else { // last layer: only bias and scale are used
        tdnn_scale = tdnn_ram_scale7;
        tdnn_bias  = tdnn_ram_bias7;
    }
    // GPIOBSET |= BIT(4);
    spiflash_lock();
    // GPIOBSET = BIT(1);
#if NPU_CONTIN_CAL_EN
    u8 matrix_cnt = 0;
    u16 float_cnt = 0;
    memset(sum_buffer, 0, 4 * MAX_NPU_MATRIX * 16);
#endif
    for (i = 0; i < out_dim; i++) {
        u16 offset = (i * in_dim) % frame_len;
        // When the current weight block has been consumed, kick the flash read for the next one.
        if (offset == 0) {
            u16 load_frame = 0;
            t->toggle ^= 1;
            if (t->all_frame == 0) { // current layer finished, advance to the next layer's geometry
                t->index++;
                if (t->index >= 7) {
                    t->index = 0;
                }
                p = &tdnn_prenum[t->index];
                t->all_frame = p->all_frame;
                load_frame = p->load_frame;
            } else {
                if (t->all_frame >= p->load_frame) {
                    load_frame = p->load_frame;
                } else {
                    load_frame = t->all_frame;
                }
            }
            u32 load_len = load_frame * p->frame_len;
            spiflash_read_wait();
            // GPIOESET |= BIT(4);
            spiflash_read_kick(t->cache[t->toggle ^ 1].buf, t->load_addr, load_len);
            // GPIOECLR |= BIT(4);
            t->all_frame -= load_frame;
            t->load_addr += load_len;
            if (t->load_addr >= (ASR_BASE_ADDR + ASR_BASE_LEN)) {
                t->load_addr = ASR_BASE_ADDR;
            }
        }
        int8_t *tdnn_buffer_t = (int8_t *)t->cache[t->toggle].buf + offset;
#if !NPU_CONTIN_CAL_EN
        matrix_sum = 0;
        u8 mod = ((int)tdnn_buffer_t % 16);
        if (mod) { // x0 must be 16-byte aligned: stage the row in npu_matrix_x first
            memcpy(npu_matrix_x, tdnn_buffer_t, in_dim);
            matrix_hw(&matrix_sum, npu_matrix_x, in_buf, in_dim);
        } else {
            matrix_hw(&matrix_sum, tdnn_buffer_t, in_buf, in_dim);
        }
        npu_matrix_kick_wait();
        sum_buffer[i] = matrix_sum;
#else
#if NPU_MEMCPY_EN
        u8 mod = ((int)tdnn_buffer_t % 16);
        if (mod) { // unaligned source: fall back to a plain byte copy
            memcpy(&npu_matrix_x[matrix_cnt * 256], tdnn_buffer_t, in_dim);
        } else {
            npu_memcpy_int16((int16_t *)&npu_matrix_x[matrix_cnt * 256], (int16_t *)tdnn_buffer_t, 128);
        }
#else
        memcpy(&npu_matrix_x[matrix_cnt * 256], tdnn_buffer_t, in_dim);
#endif
        matrix_hw(&sum_buffer[matrix_cnt * 16], &npu_matrix_x[matrix_cnt * 256], in_buf, in_dim);
        matrix_cnt++;
        if (out_dim == i + 1) {
            npu_matrix_kick_wait();
            tdnn_compute_float_cal(out_buf, tdnn_mean, tdnn_scale, tdnn_var, tdnn_bias, in_scale, matrix_cnt, float_cnt, last_layer);
            float_cnt += matrix_cnt;
            memset(sum_buffer, 0, 4 * MAX_NPU_MATRIX * 16);
            matrix_cnt = 0;
        } else {
            if (matrix_cnt == MAX_NPU_MATRIX) {
                npu_matrix_kick_wait();
                tdnn_compute_float_cal(out_buf, tdnn_mean, tdnn_scale, tdnn_var, tdnn_bias, in_scale, matrix_cnt, float_cnt, last_layer);
                float_cnt += matrix_cnt;
                memset(sum_buffer, 0, 4 * MAX_NPU_MATRIX * 16);
                matrix_cnt = 0;
            }
        }
#endif
    }
    spiflash_read_wait();
    spiflash_unlock();
#if !NPU_CONTIN_CAL_EN
    for (i = 0; i < out_dim; i++) {
        matrix_sum = sum_buffer[i];
        int32_t z = matrix_sum * in_scale + tdnn_bias[i];
        if (!last_layer) {
            if (z <= 0) {
                z = 0;
            }
            out_buf[i] = (z - tdnn_mean[i]) * tdnn_var[i];
        } else {
            out_buf[i] = z * tdnn_scale[i];
        }
    }
#endif
}
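// How the loop above overlaps flash I/O with NPU work: t->toggle selects the half of
// tdnn_buffer the NPU is currently reading weight rows from. Each time offset wraps to
// zero (one prefetched block fully consumed), the code waits for the previous flash
// read, kicks an asynchronous spiflash_read_kick() into the idle half, and advances
// load_addr, wrapping back to ASR_BASE_ADDR once ASR_BASE_ADDR + ASR_BASE_LEN is
// reached. When a layer's all_frame count hits zero, t->index steps to the next
// tdnn_prenum entry, so the seven layers stream from flash in a fixed rotation.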
ALIGNED(512)
void asr_prefetch_init_do(void)
{
    tdnn_pretetch_t *t = &tdnn_pretetch;
    memset(t, 0, sizeof(tdnn_pretetch_t));
    const tdnn_prenum_t *p = &tdnn_prenum[t->index];
    t->load_addr = ASR_BASE_ADDR;
    t->cache[0].buf = tdnn_buffer[0];
    t->cache[1].buf = tdnn_buffer[1];
    // Prefetch the first weight block before the first tdnn_compute() call.
    u32 load_len = p->load_frame * p->frame_len;
    spiflash_lock();
    spiflash_read_kick(t->cache[0].buf, t->load_addr, load_len);
    spiflash_read_wait();
    t->cache[0].load_frame = p->load_frame;
    t->all_frame = p->all_frame - p->load_frame; // the first block is already loaded, so subtract its frame count
    t->load_addr += load_len;
    spiflash_unlock();
    memcpy(tdnn_ram_mean1, tdnn_mean_1, sizeof(tdnn_mean_1));
    memcpy(tdnn_ram_bias1, tdnn_bias_1, sizeof(tdnn_bias_1));
    memcpy(tdnn_ram_var1, tdnn_var_1, sizeof(tdnn_var_1));
    memcpy(tdnn_ram_mean2, tdnn_mean_2, sizeof(tdnn_mean_2));
    memcpy(tdnn_ram_bias2, tdnn_bias_2, sizeof(tdnn_bias_2));
    memcpy(tdnn_ram_var2, tdnn_var_2, sizeof(tdnn_var_2));
    memcpy(tdnn_ram_mean3, tdnn_mean_3, sizeof(tdnn_mean_3));
    memcpy(tdnn_ram_bias3, tdnn_bias_3, sizeof(tdnn_bias_3));
    memcpy(tdnn_ram_var3, tdnn_var_3, sizeof(tdnn_var_3));
    memcpy(tdnn_ram_mean4, tdnn_mean_4, sizeof(tdnn_mean_4));
    memcpy(tdnn_ram_bias4, tdnn_bias_4, sizeof(tdnn_bias_4));
    memcpy(tdnn_ram_var4, tdnn_var_4, sizeof(tdnn_var_4));
    memcpy(tdnn_ram_mean5, tdnn_mean_5, sizeof(tdnn_mean_5));
    memcpy(tdnn_ram_bias5, tdnn_bias_5, sizeof(tdnn_bias_5));
    memcpy(tdnn_ram_var5, tdnn_var_5, sizeof(tdnn_var_5));
    memcpy(tdnn_ram_mean6, tdnn_mean_6, sizeof(tdnn_mean_6));
    memcpy(tdnn_ram_bias6, tdnn_bias_6, sizeof(tdnn_bias_6));
    memcpy(tdnn_ram_var6, tdnn_var_6, sizeof(tdnn_var_6));
    memcpy(tdnn_ram_bias7, tdnn_bias_7, sizeof(tdnn_bias_7));
    memcpy(tdnn_ram_scale7, tdnn_scale_7, sizeof(tdnn_scale_7));
}
void asr_prefetch_init(void)
{
    memset(asr_offset, 0, 12);
    asr_prefetch_init_do();
    asr_prefetch_kisck = 5;
}
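#if 0
/*
 * Hypothetical caller sketch, not compiled in; the real driver lives in the ASR
 * library. It only shows the ordering visible from this file: run asr_prefetch_init()
 * once so the first weight block is cached and the mean/bias/var tables sit in RAM,
 * then issue one tdnn_compute() call per layer. asr_prefetch_init() leaves
 * asr_prefetch_kisck at 5 and tdnn_compute() returns immediately while that flag is
 * nonzero, so this sketch assumes the caller clears it first. The feature buffer,
 * scale and weight pointer are placeholders (the weights actually consumed come from
 * the flash prefetch); layer 1's dimensions (in_dim 120, out_dim 128) follow
 * tdnn_prenum[0].
 */
static void asr_prefetch_caller_sketch(int8_t *feat_q7, float feat_scale, float *layer1_out)
{
    asr_prefetch_init();
    asr_prefetch_kisck = 0;   /* assumption: the caller un-gates tdnn_compute() */
    tdnn_compute(feat_q7, feat_scale, 120, 128,
                 (const int8_t *)0, tdnn_bias_1,   /* weight pointer is only used in the trace printf */
                 tdnn_mean_1, tdnn_var_1, (const float *)0,
                 layer1_out, 0 /* not the last layer */);
}
#endif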
#endif // ASR_PREFETCH_EN && ASR_EN