use std::cmp::min;
use std::mem::size_of;
use std::ptr::copy_nonoverlapping;
use aligned_alloc::Alloc;
use util::range_chunk;
use util::round_up_to;
use kernel::ConstNum;
use kernel::Element;
use kernel::GemmKernel;
use kernel::GemmSelect;
use sgemm_kernel;
use dgemm_kernel;
use rawpointer::PointerExt;
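
/// General matrix multiplication, f32 elements: C ← α A B + β C
///
/// + m, k, n: matrix dimensions (A is m × k, B is k × n, C is m × n)
/// + a, b, c: pointer to the first element of the respective matrix
/// + rsx, csx: row and column stride of matrix x, in elements
///
/// Strides for A and B may be arbitrary; strides for C must not let any two
/// elements alias. If β is zero, C does not need to be initialized.
///
/// A minimal usage sketch for row-major 2 × 2 operands (illustrative only; it
/// assumes `sgemm` is in scope, e.g. re-exported from the crate root):
///
/// ```ignore
/// let (m, k, n) = (2, 2, 2);
/// let a = [1.0_f32, 2., 3., 4.];   // A = [[1, 2], [3, 4]]
/// let b = [5.0_f32, 6., 7., 8.];   // B = [[5, 6], [7, 8]]
/// let mut c = [0.0_f32; 4];
/// unsafe {
///     sgemm(m, k, n, 1.0,
///           a.as_ptr(), k as isize, 1,  // row-major: row stride k, col stride 1
///           b.as_ptr(), n as isize, 1,
///           0.0,
///           c.as_mut_ptr(), n as isize, 1);
/// }
/// assert_eq!(c, [19., 22., 43., 50.]);
/// ```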
pub unsafe fn sgemm(
m: usize, k: usize, n: usize,
alpha: f32,
a: *const f32, rsa: isize, csa: isize,
b: *const f32, rsb: isize, csb: isize,
beta: f32,
c: *mut f32, rsc: isize, csc: isize)
{
sgemm_kernel::detect(GemmParameters { m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc})
}
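
/// General matrix multiplication, f64 elements: C ← α A B + β C
///
/// Same contract and parameter meaning as [`sgemm`], with f64 elements.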
pub unsafe fn dgemm(
m: usize, k: usize, n: usize,
alpha: f64,
a: *const f64, rsa: isize, csa: isize,
b: *const f64, rsb: isize, csb: isize,
beta: f64,
c: *mut f64, rsc: isize, csc: isize)
{
dgemm_kernel::detect(GemmParameters { m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc})
}
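
/// The complete set of gemm arguments, bundled so they can be handed to the
/// kernel-detection code and passed back into `select` with the chosen kernel.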
struct GemmParameters<T> {
m: usize, k: usize, n: usize,
alpha: T,
a: *const T, rsa: isize, csa: isize,
beta: T,
b: *const T, rsb: isize, csb: isize,
c: *mut T, rsc: isize, csc: isize,
}
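
// Entry point after kernel detection: `K` identifies the kernel chosen for the
// current CPU; the `_kernel` value itself is only a type-level marker.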
impl<T> GemmSelect<T> for GemmParameters<T> {
fn select<K>(self, _kernel: K)
where K: GemmKernel<Elem=T>,
T: Element,
{
let GemmParameters {
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc} = self;
unsafe {
gemm_loop::<K>(
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc)
}
}
}
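
/// Sanity-check the kernel's compile-time parameters against assumptions made
/// elsewhere: the on-stack mask buffer in `gemm_packed` must fit a full
/// MR × NR tile, and the requested alignment must be attainable when packing.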
#[inline(always)]
fn ensure_kernel_params<K>()
where K: GemmKernel
{
let mr = K::MR;
let nr = K::NR;
assert!(mr > 0 && mr <= 8);
assert!(nr > 0 && nr <= 8);
assert!(mr * nr * size_of::<K::Elem>() <= 8 * 4 * 8);
assert!(K::align_to() <= 32);
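    // One packed row/column (min(MR, NR) elements) is what limits the largest
    // alignment the packed layout can guarantee.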
let max_align = size_of::<K::Elem>() * min(mr, nr);
assert!(K::align_to() <= max_align);
}
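
/// Blocked matrix multiplication over packed operands.
///
/// Loops 5 and 3 tile C into nc × mc blocks, loop 4 slices k into kc-long
/// slabs; the corresponding parts of B and A are packed into contiguous,
/// aligned buffers and handed to `gemm_packed`, which runs the micro-kernel
/// over them (loops 2 and 1).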
unsafe fn gemm_loop<K>(
m: usize, k: usize, n: usize,
alpha: K::Elem,
a: *const K::Elem, rsa: isize, csa: isize,
b: *const K::Elem, rsb: isize, csb: isize,
beta: K::Elem,
c: *mut K::Elem, rsc: isize, csc: isize)
where K: GemmKernel
{
debug_assert!(m <= 1 || n == 0 || rsc != 0);
debug_assert!(m == 0 || n <= 1 || csc != 0);
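    // A · B contributes nothing (some dimension is zero): just apply C ← β C.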
if m == 0 || k == 0 || n == 0 {
return c_to_beta_c(m, n, beta, c, rsc, csc);
}
let knc = K::nc();
let kkc = K::kc();
let kmc = K::mc();
ensure_kernel_params::<K>();
let (mut packing_buffer, bp_offset) = make_packing_buffer::<K>(m, k, n);
let app = packing_buffer.ptr_mut();
let bpp = app.add(bp_offset);
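    // LOOP 5: split n into chunks of nc columns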
for (l5, nc) in range_chunk(n, knc) {
dprint!("LOOP 5, {}, nc={}", l5, nc);
let b = b.stride_offset(csb, knc * l5);
let c = c.stride_offset(csc, knc * l5);
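        // LOOP 4: split k into chunks of kc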
for (l4, kc) in range_chunk(k, kkc) {
dprint!("LOOP 4, {}, kc={}", l4, kc);
let b = b.stride_offset(rsb, kkc * l4);
let a = a.stride_offset(csa, kkc * l4);
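            // Pack the kc × nc block of B into NR-wide micropanels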
pack::<K::NRTy, _>(kc, nc, bpp, b, csb, rsb);
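            // LOOP 3: split m into chunks of mc rows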
for (l3, mc) in range_chunk(m, kmc) {
dprint!("LOOP 3, {}, mc={}", l3, mc);
let a = a.stride_offset(rsa, kmc * l3);
let c = c.stride_offset(rsc, kmc * l3);
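                // Pack the mc × kc block of A into MR-high micropanels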
pack::<K::MRTy, _>(kc, mc, app, a, rsa, csa);
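                // β is applied only while writing the first kc-slab of results;
                // later slabs accumulate into C, so they use β = 1.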
let betap = if l4 == 0 { beta } else { <_>::one() };
gemm_packed::<K>(nc, kc, mc,
alpha,
app, bpp,
betap,
c, rsc, csc);
}
}
}
}
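
/// Multiply the packed operands into C.
///
/// Loop 2 walks the NR-wide micropanels of packed B, loop 1 the MR-high
/// micropanels of packed A, invoking the micro-kernel once per MR × NR tile
/// of C. Partial tiles at the right/bottom edge (or kernels that request it)
/// go through `masked_kernel`.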
unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
alpha: K::Elem,
app: *const K::Elem, bpp: *const K::Elem,
beta: K::Elem,
c: *mut K::Elem, rsc: isize, csc: isize)
where K: GemmKernel,
{
let mr = K::MR;
let nr = K::NR;
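    // The masked kernel computes into this on-stack scratch tile first; make
    // sure a full MR × NR tile fits in its 256 bytes at 32-byte alignment.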
assert!(mr * nr * size_of::<K::Elem>() <= 256 && K::align_to() <= 32);
let mut mask_buf = [0u8; 256 + 31];
let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::Elem;
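    // LOOP 2: through micropanels of packed B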
for (l2, nr_) in range_chunk(nc, nr) {
let bpp = bpp.stride_offset(1, kc * nr * l2);
let c = c.stride_offset(csc, nr * l2);
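        // LOOP 1: through micropanels of packed A, with the B micropanel fixed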
for (l1, mr_) in range_chunk(mc, mr) {
let app = app.stride_offset(1, kc * mr * l1);
let c = c.stride_offset(rsc, mr * l1);
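            // Use the masked kernel for partial edge tiles (or whenever the
            // kernel implementation asks for it); otherwise write straight to C.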
            if K::always_masked() || nr_ < nr || mr_ < mr {
                masked_kernel::<_, K>(kc, alpha, &*app, &*bpp,
                                      beta, &mut *c, rsc, csc,
                                      mr_, nr_, mask_ptr);
            } else {
                K::kernel(kc, alpha, app, bpp, beta, c, rsc, csc);
            }
}
}
}
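
/// Allocate a single aligned buffer that holds the packed block of A followed
/// by the packed block of B; returns the allocation and the element offset at
/// which the B region starts.
///
/// Only one mc × kc block of A and one kc × nc block of B are packed at a
/// time, so the requested dimensions are clamped to one block. Because the A
/// region is a whole number of MR-element micropanels and the alignment
/// requirement is at most min(MR, NR) elements (see `ensure_kernel_params`),
/// the B region starts suitably aligned as well.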
unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize) -> (Alloc<K::Elem>, usize)
where K: GemmKernel,
{
let m = min(m, K::mc());
let k = min(k, K::kc());
let n = min(n, K::nc());
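    // Round each packed width up to whole micropanels; `pack` zero-pads the tail.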
let apack_size = k * round_up_to(m, K::MR);
let bpack_size = k * round_up_to(n, K::NR);
let nelem = apack_size + bpack_size;
dprint!("packed nelem={}, apack={}, bpack={},
m={} k={} n={}",
nelem, apack_size, bpack_size,
m,k,n);
(Alloc::new(nelem, K::align_to()), apack_size)
}
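
/// Offset `ptr` forwards to the next `align_to`-byte boundary (no-op if the
/// pointer is already aligned or `align_to` is zero); only used with byte
/// pointers here.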
unsafe fn align_ptr<U>(align_to: usize, mut ptr: *mut U) -> *mut U {
if align_to != 0 {
let cur_align = ptr as usize % align_to;
if cur_align != 0 {
ptr = ptr.offset(((align_to - cur_align) / size_of::<U>()) as isize);
}
}
ptr
}
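
/// Pack a kc × mc block of a matrix into `pack` as consecutive micropanels of
/// MR::VALUE rows by kc columns; a trailing partial micropanel is zero-padded.
///
/// + kc: length of each micropanel (the shared k dimension)
/// + mc: number of rows of A (or columns of B) to pack
/// + rsa: stride along the direction being split into micropanels
/// + csa: stride along k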
unsafe fn pack<MR, T>(kc: usize, mc: usize, pack: *mut T,
a: *const T, rsa: isize, csa: isize)
where T: Element,
MR: ConstNum,
{
let mr = MR::VALUE;
let mut p = 0;
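    // Fast path: the source is contiguous along the packing direction, so each
    // MR-element run can be copied in one go.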
if rsa == 1 {
for ir in 0..mc/mr {
let row_offset = ir * mr;
for j in 0..kc {
let a_row = a.stride_offset(rsa, row_offset)
.stride_offset(csa, j);
copy_nonoverlapping(a_row, pack.add(p), mr);
p += mr;
}
}
} else {
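        // General strided case: gather one element at a time.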
for ir in 0..mc/mr {
let row_offset = ir * mr;
for j in 0..kc {
for i in 0..mr {
let a_elt = a.stride_offset(rsa, i + row_offset)
.stride_offset(csa, j);
copy_nonoverlapping(a_elt, pack.add(p), 1);
p += 1;
}
}
}
}
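    // Zero-pad the final partial micropanel when mc is not a multiple of MR.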
let zero = <_>::zero();
let rest = mc % mr;
if rest > 0 {
let row_offset = (mc/mr) * mr;
for j in 0..kc {
for i in 0..mr {
if i < rest {
let a_elt = a.stride_offset(rsa, i + row_offset)
.stride_offset(csa, j);
copy_nonoverlapping(a_elt, pack.add(p), 1);
} else {
*pack.add(p) = zero;
}
p += 1;
}
}
}
}
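
/// Micro-kernel wrapper for partial tiles: run the kernel with α = 1, β = 0
/// into the column-major scratch tile `mask_buf`, then copy only the valid
/// `rows` × `cols` portion into C, applying the real α and β there.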
#[inline(never)]
unsafe fn masked_kernel<T, K>(k: usize, alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T, rsc: isize, csc: isize,
rows: usize, cols: usize,
mask_buf: *mut T)
where K: GemmKernel<Elem=T>, T: Element,
{
let mr = K::MR;
let nr = K::NR;
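    // Compute the full mr × nr tile into the scratch buffer, stored
    // column-major (row stride 1, column stride mr).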
K::kernel(k, T::one(), a, b, T::zero(), mask_buf, 1, mr as isize);
let mut ab = mask_buf;
for j in 0..nr {
for i in 0..mr {
if i < rows && j < cols {
let cptr = c.stride_offset(rsc, i)
.stride_offset(csc, j);
if beta.is_zero() {
*cptr = T::zero();
} else {
(*cptr).scale_by(beta);
}
(*cptr).scaled_add(alpha, *ab);
}
ab.inc();
}
}
}
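
/// C ← β C (or C ← 0 when β is zero); used when A · B contributes nothing.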
unsafe fn c_to_beta_c<T>(m: usize, n: usize, beta: T,
c: *mut T, rsc: isize, csc: isize)
where T: Element
{
for i in 0..m {
for j in 0..n {
let cptr = c.stride_offset(rsc, i)
.stride_offset(csc, j);
if beta.is_zero() {
*cptr = T::zero();
} else {
(*cptr).scale_by(beta);
}
}
}
}