73 template <
typename Shape>
    84     "Mma_HFMA2 requires the M dimension to be divisible by 2."   121       arch::OpMultiplyAdd>;
   123     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   124     Array<half_t, 2> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&A);
   125     Array<half_t, 1> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&B);
   130     for(
auto k=0; k <  Shape::kK / Mma::Shape::kK; k++){
   133       for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
   136         for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
   138             Array<half_t, 2> tmp;
   139             Array<half_t, 2> *ptr_tmp = &tmp;
   140             ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
   144                 ptr_A[k*Shape::kM/2 + m],
   145                 ptr_B[n*Shape::kK + k],
   148             ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
   159 template <
typename Shape>
   170     "Mma_HFMA2 requires the N dimension to be divisible by 2."   207       arch::OpMultiplyAdd>;
   209     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   210     Array<half_t, 1> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&A);
   211     Array<half_t, 2> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&B);
   216     for(
auto k=0; k <  Shape::kK / Mma::Shape::kK; k++){
   219         for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
   222           for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
   224             Array<half_t, 2> tmp;
   225             Array<half_t, 2> *ptr_tmp = &tmp;
   226             ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
   228             Array<half_t, 2> tmp_B;
   229             tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
   230             tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
   234                 ptr_A[k*Shape::kM + m],
   238             ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
   250 template <
typename Shape>
   261     "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."   297       arch::OpMultiplyAdd>;
   299     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   300     Array<half_t, 2> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&A);
   301     Array<half_t, 1> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&B);
   306     for (
int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
   309         for (
int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) {
   312           for (
int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) {
   314           Array<half_t, 2> tmp;
   315           Array<half_t, 2> *ptr_tmp = &tmp;
   317           ptr_tmp[0] = ptr_D[m + n * Shape::kM/2];
   321             ptr_A[m + k * Shape::kM/2],
   322             ptr_B[k * Shape::kN + n],
   325           ptr_D[m + n * Shape::kM/2] = ptr_tmp[0];
   336 template <
typename Shape>
   347     "Mma_HFMA2 requires the N dimension to be divisible by 2."   384       arch::OpMultiplyAdd>;
   386     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   387     Array<half_t, 1> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&A);
   388     Array<half_t, 2> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&B);
   393     for(
auto k=0; k <  Shape::kK / Mma::Shape::kK; k++){
   396         for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
   399           for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
   401             Array<half_t, 2> tmp;
   402             Array<half_t, 2> *ptr_tmp = &tmp;
   403             ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
   407                 ptr_A[k*Shape::kM + m],
   408                 ptr_B[k*Shape::kN/2 + n],
   411             ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
   423 template <
typename Shape>
   434     "Mma_HFMA2 requires the M dimension to be divisible by 2."   471       arch::OpMultiplyAdd>;
   473     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   474     Array<half_t, 2> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&A);
   475     Array<half_t, 1> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&B);
   480     for(
auto k=0; k <  Shape::kK / Mma::Shape::kK; k++){
   483       for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
   486         for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
   488             Array<half_t, 2> tmp;
   489             Array<half_t, 2> *ptr_tmp = &tmp;
   490             ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
   492             Array<half_t, 2> tmp_A;
   493             tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
   494             tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
   499                 ptr_B[n*Shape::kK + k],
   502             ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
   513 template <
typename Shape>
   524     "Mma_HFMA2 requires the N dimension to be divisible by 2."   561       arch::OpMultiplyAdd>;
   563     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   564     Array<half_t, 1> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&A);
   565     Array<half_t, 2> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&B);
   570     for(
auto k=0; k <  Shape::kK / Mma::Shape::kK; k++){
   573         for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
   576           for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
   578             Array<half_t, 2> tmp;
   579             Array<half_t, 2> *ptr_tmp = &tmp;
   580             ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
   582             Array<half_t, 2> tmp_B;
   583             tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
   584             tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
   588                 ptr_A[m*Shape::kK + k],
   592             ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
   603 template <
typename Shape>
   614     "Mma_HFMA2 requires the M dimension to be divisible by 2."   651       arch::OpMultiplyAdd>;
   653     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   654     Array<half_t, 2> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&A);
   655     Array<half_t, 1> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&B);
   660     for(
auto k=0; k <  Shape::kK / Mma::Shape::kK; k++){
   663       for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
   666         for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
   668             Array<half_t, 2> tmp;
   669             Array<half_t, 2> *ptr_tmp = &tmp;
   670             ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
   672             Array<half_t, 2> tmp_A;
   673             tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
   674             tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
   679                 ptr_B[k*Shape::kN + n],
   682             ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
   694 template <
typename Shape>
   705     "Mma_HFMA2 requires the N dimension to be divisible by 2."   742       arch::OpMultiplyAdd>;
   744     Array<half_t, 2> *ptr_D = 
reinterpret_cast<Array<half_t, 2> *
>(&D);
   745     Array<half_t, 1> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 1> 
const *
>(&A);
   746     Array<half_t, 2> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&B);
   751     for(
auto k=0; k <  Shape::kK / Mma::Shape::kK; k++){
   754         for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
   757           for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
   759             Array<half_t, 2> tmp;
   760             Array<half_t, 2> *ptr_tmp = &tmp;
   761             ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
   765                 ptr_A[m*Shape::kK + k],
   766                 ptr_B[k*Shape::kN/2 + n],
   769             ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
   780 template <
typename Shape, 
typename LayoutA, 
typename LayoutB>
   791     "Mma_HFMA2 requires the K dimension to be divisible by 2."   821     Array<half_t, 1> *ptr_D = 
reinterpret_cast<Array<half_t, 1> *
>(&D);
   822     Array<half_t, 2> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&A);
   823     Array<half_t, 2> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&B);
   835         Array<half_t, 2> tmp_C;
   837         Array<half_t, 1> *ptr_tmp_C = 
reinterpret_cast<Array<half_t, 1> *
>(&tmp_C);
   838         ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
   842           tmp_C = 
mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
   845         Array<half_t, 1> res;
   846         Array<half_t, 1> *ptr_res = &res;
   849         ptr_D[m*Shape::kN + n] = ptr_res[0];
   859 template <
typename Shape, 
typename LayoutA, 
typename LayoutB>
   870     "Mma_HFMA2 requires the K dimension to be divisible by 2."   900     Array<half_t, 1> *ptr_D = 
reinterpret_cast<Array<half_t, 1> *
>(&D);
   901     Array<half_t, 2> 
const *ptr_A = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&A);
   902     Array<half_t, 2> 
const *ptr_B = 
reinterpret_cast<Array<half_t, 2> 
const *
>(&B);
   914         Array<half_t, 2> tmp_C;
   916         Array<half_t, 1> *ptr_tmp_C = 
reinterpret_cast<Array<half_t, 1> *
>(&tmp_C);
   917         ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
   922           tmp_C = 
mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
   926         Array<half_t, 1> res;
   927         Array<half_t, 1> *ptr_res = &res;
   930         ptr_D[n*Shape::kM + m] = ptr_res[0];
   943   typename Shape_, 
typename LayoutA, 
typename LayoutB, 
typename LayoutC
   997     constexpr bool m_mod2 = !(Shape::kM % 2);
   998     constexpr bool n_mod2 = !(Shape::kN % 2);
   999     constexpr bool k_mod2 = !(Shape::kK % 2);
  1007     constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
  1008     constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
  1009     constexpr bool use_optimized =  (use_outer_prod || use_inner_prod);
  1032     static bool const kIsConventionalLayout =
  1038     static bool const value = kIsConventionalLayout;
  1059   arch::OpMultiplyAdd,
  1067   using LayoutA = LayoutA_;
  1069   using LayoutB = LayoutB_;
  1082     arch::OpMultiplyAdd,
 Fused multiply-add. 
Definition: functional.h:92
Determines whether to enable thread::Gemm<> specializations compatible with SM50. ...
Definition: gemm/thread/mma_sm60.h:1030
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:801
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:271
Definition: aligned_buffer.h:35
Defines a structure containing strides, bounds, and a pointer to tensor data. 
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:94
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:528
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:809
Structure to compute the matrix product for HFMA. 
Definition: gemm/thread/mma_sm60.h:66
Array< ElementC, Shape::kMN > FragmentC
Definition: gemm/thread/mma_sm60.h:1087
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:441
IEEE half-precision floating-point type. 
Definition: half.h:126
Defines common types used for all GEMM-like operators. 
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Definition: gemm/thread/mma_sm60.h:1090
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:444
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:438
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:357
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:102
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:723
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:632
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:174
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:712
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:624
Mapping function for column-major matrices. 
Definition: layout/matrix.h:142
arch::OpMultiplyAdd Operator
Definition: gemm/thread/mma_sm60.h:1072
static int const kK
Definition: include/cutlass/gemm/gemm.h:60
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:177
Array< ElementB, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:975
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:91
Array< ElementB, Shape::kKN > FragmentB
Definition: gemm/thread/mma_sm60.h:1086
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:365
arch::OpMultiplyAdd Operator
Underlying mathematical operator. 
Definition: gemm/thread/mma_sm60.h:969
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:709
Defines transposes of matrix layouts. 
Definition: layout/matrix.h:921
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:531
Template that handles all packed matrix layouts. 
Definition: gemm/thread/mma_sm50.h:65
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:888
Defines basic thread level reduction with specializations for Array<T, N>. 
Array< ElementC, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:978
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:188
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Templates exposing architecture support for warp-level multiply-add operations. 
Shape of a matrix multiply-add operation. 
Definition: include/cutlass/gemm/gemm.h:57
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:265
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:88
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:880
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:986
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:795
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:351
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:452
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:621
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<> 
Definition: gemm/thread/mma_sm60.h:957
Structure to compute the matrix product. 
Definition: gemm/thread/mma.h:66
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:877
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:715
Defines layout functions used by TensorRef and derived classes. 
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:534
Array< half_t, Shape::kMN > FragmentC
C operand storage. 
Definition: gemm/thread/mma_sm60.h:180
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:798
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:618
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:268
Matrix multiply-add operation. 
Definition: arch/mma.h:92
Array< ElementA, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:972
Array< half_t, Shape::kKN > FragmentB
B operand storage. 
Definition: gemm/thread/mma_sm60.h:354
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:542
Basic include for CUTLASS. 
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C. 
Definition: gemm/thread/mma_sm60.h:279
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Array< ElementA, Shape::kMK > FragmentA
Definition: gemm/thread/mma_sm60.h:1085
Structure to compute the thread level reduction. 
Definition: reduce.h:43
CUTLASS_HOST_DEVICE Array< T, N > mac(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c)
Definition: simd.h:84
Shape_ Shape
Definition: gemm/thread/mma_sm60.h:1065
Array< half_t, Shape::kMK > FragmentA
A operand storage. 
Definition: gemm/thread/mma_sm60.h:874
static int const kN
Definition: include/cutlass/gemm/gemm.h:59