diff --git a/kernel/x86_64/drot_microk_haswell-2.c b/kernel/x86_64/drot_microk_haswell-2.c
index cc5949b1ad..17b2798837 100644
--- a/kernel/x86_64/drot_microk_haswell-2.c
+++ b/kernel/x86_64/drot_microk_haswell-2.c
@@ -12,10 +12,8 @@ static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
     BLASLONG tail_index_16 = n&(~15);
 
     __m256d c_256, s_256;
-    if (n >= 4) {
-        c_256 = _mm256_set1_pd(c);
-        s_256 = _mm256_set1_pd(s);
-    }
+    c_256 = _mm256_set1_pd(c);
+    s_256 = _mm256_set1_pd(s);
 
     __m256d x0, x1, x2, x3;
     __m256d y0, y1, y2, y3;
@@ -76,10 +74,20 @@ static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
         _mm256_storeu_pd(&y[i], t0);
     }
 
-    for (i = tail_index_4; i < n; ++i) {
-        FLOAT temp = c * x[i] + s * y[i];
-        y[i] = c * y[i] - s * x[i];
-        x[i] = temp;
+    if ((n & 3) > 0) {
+        const int64_t mask_v[8] = {-1,-1,-1,-1, 0,0,0,0};
+        __m256i tail_mask = _mm256_loadu_si256((__m256i*)&mask_v[4 - (n & 3)]);
+
+        x0 = _mm256_maskload_pd(&x[tail_index_4], tail_mask);
+        y0 = _mm256_maskload_pd(&y[tail_index_4], tail_mask);
+
+        t0 = _mm256_mul_pd(s_256, y0);
+        t0 = _mm256_fmadd_pd(c_256, x0, t0);
+        _mm256_maskstore_pd(&x[tail_index_4], tail_mask, t0);
+
+        t0 = _mm256_mul_pd(s_256, x0);
+        t0 = _mm256_fmsub_pd(c_256, y0, t0);
+        _mm256_maskstore_pd(&y[tail_index_4], tail_mask, t0);
     }
 }
 #endif
diff --git a/kernel/x86_64/srot_microk_haswell-2.c b/kernel/x86_64/srot_microk_haswell-2.c
index b5545726eb..75243c7272 100644
--- a/kernel/x86_64/srot_microk_haswell-2.c
+++ b/kernel/x86_64/srot_microk_haswell-2.c
@@ -13,10 +13,8 @@ static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
     BLASLONG tail_index_32 = n&(~31);
 
     __m256 c_256, s_256;
-    if (n >= 8) {
-        c_256 = _mm256_set1_ps(c);
-        s_256 = _mm256_set1_ps(s);
-    }
+    c_256 = _mm256_set1_ps(c);
+    s_256 = _mm256_set1_ps(s);
 
     __m256 x0, x1, x2, x3;
     __m256 y0, y1, y2, y3;
@@ -77,10 +75,20 @@ static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
         _mm256_storeu_ps(&y[i], t0);
     }
 
-    for (i = tail_index_8; i < n; ++i) {
-        FLOAT temp = c * x[i] + s * y[i];
-        y[i] = c * y[i] - s * x[i];
-        x[i] = temp;
+    if ((n & 7) > 0) {
+        const int32_t mask_v[16] = {-1,-1,-1,-1, -1,-1,-1,-1,0,0,0,0,0,0,0,0};
+        __m256i tail_mask = _mm256_loadu_si256((__m256i*)&mask_v[8 - (n & 7)]);
+
+        x0 = _mm256_maskload_ps(&x[tail_index_8], tail_mask);
+        y0 = _mm256_maskload_ps(&y[tail_index_8], tail_mask);
+
+        t0 = _mm256_mul_ps(s_256, y0);
+        t0 = _mm256_fmadd_ps(c_256, x0, t0);
+        _mm256_maskstore_ps(&x[tail_index_8], tail_mask, t0);
+
+        t0 = _mm256_mul_ps(s_256, x0);
+        t0 = _mm256_fmsub_ps(c_256, y0, t0);
+        _mm256_maskstore_ps(&y[tail_index_8], tail_mask, t0);
     }
 }
 #endif