device_grouped_contraction_multiple_d.hpp Source File

device_grouped_contraction_multiple_d.hpp Source File

Composable Kernel: device_grouped_contraction_multiple_d.hpp Source File
device_grouped_contraction_multiple_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15template <index_t NumDTensor>
17{
18 std::vector<index_t> a_ms_ks_lengths;
19 std::vector<index_t> a_ms_ks_strides;
20
21 std::vector<index_t> b_ns_ks_lengths;
22 std::vector<index_t> b_ns_ks_strides;
23
24 std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths;
25 std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides;
26
27 std::vector<index_t> e_ms_ns_lengths;
28 std::vector<index_t> e_ms_ns_strides;
29};
30
31// Tensor Contraction:
32// input : A
33// input : B
34// input : D0, D1, ...
35// output : E
36// C = a_op(A) * b_op(B)
37// E = cde_op(C, D0, D1, ...)
38// Assume:
39// A[M0, M1, M2, ..., K0, K1, K2, ...]
40// B[N0, N1, N2, ..., K0, K1, K2, ...]
41// D[M0, M1, M2, ..., N0, N1, N2, ...]
42// E[M0, M1, M2, ..., N0, N1, N2, ...]
// Here "*" denotes a contraction: the K dimensions shared by A and B are
// summed over, presumably leaving C indexed by [M..., N...] like D and E.
// "Grouped": a single argument describes several contractions. Each vector
// passed to MakeArgumentPointer presumably holds one entry per group, indexed
// in parallel with contraction_descs (TODO confirm against implementations).
//
// Template parameters (meaning inferred from the names and the layout above):
//   NumDimM / NumDimN / NumDimK - presumably the number of M / N / K dimensions.
//   ADataType, BDataType, EDataType - presumably element types of A, B and E.
//   DsDataType - type list describing the D tensors; its Size() gives
//                NumDTensor.
//   AElementwiseOperation / BElementwiseOperation - a_op / b_op above.
//   CDEElementwiseOperation - cde_op above.
//
// NOTE(review): the declaration line between the template parameter list and
// the opening brace (listing line 53) is missing from this listing; presumably
// it is `struct DeviceGroupedContractionMultipleD : public BaseOperator` -
// confirm against the real header.
43template <index_t NumDimM,
44 index_t NumDimN,
45 index_t NumDimK,
46 typename ADataType,
47 typename BDataType,
48 typename DsDataType,
49 typename EDataType,
50 typename AElementwiseOperation,
51 typename BElementwiseOperation,
52 typename CDEElementwiseOperation>
54{
 // Number of D tensors, taken from the DsDataType type list.
55 static constexpr index_t NumDTensor = DsDataType::Size();
56
 // Pure virtual: builds a type-erased argument object for the whole group.
 //   p_a_vec / p_b_vec - untyped read-only pointers to the A / B buffers.
 //   p_ds_vec          - per entry, NumDTensor read-only D buffer pointers.
 //   p_e_vec           - untyped writable pointers to the E (output) buffers.
 //   contraction_descs - shapes/strides of A, B, Ds and E (see ContractionDesc).
 //   a_element_op / b_element_op / cde_element_op - a_op / b_op / cde_op above.
 // Returns ownership of the argument through std::unique_ptr<BaseArgument>.
57 virtual std::unique_ptr<BaseArgument>
58 MakeArgumentPointer(std::vector<const void*> p_a_vec,
59 std::vector<const void*> p_b_vec,
60 std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
61 std::vector<void*> p_e_vec,
62 std::vector<ContractionDesc<NumDTensor>> contraction_descs,
63 AElementwiseOperation a_element_op,
64 BElementwiseOperation b_element_op,
65 CDEElementwiseOperation cde_element_op) = 0;
66
 // Pure virtual: creates the invoker that runs an argument built above;
 // ownership is returned through std::unique_ptr<BaseInvoker>.
67 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
68};
69
70} // namespace device
71} // namespace tensor_operation
72} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_grouped_contraction_multiple_d.hpp:17
std::array< std::vector< index_t >, NumDTensor > ds_ms_ns_lengths
Definition device_grouped_contraction_multiple_d.hpp:24
std::vector< index_t > b_ns_ks_lengths
Definition device_grouped_contraction_multiple_d.hpp:21
std::vector< index_t > a_ms_ks_strides
Definition device_grouped_contraction_multiple_d.hpp:19
std::vector< index_t > a_ms_ks_lengths
Definition device_grouped_contraction_multiple_d.hpp:18
std::array< std::vector< index_t >, NumDTensor > ds_ms_ns_strides
Definition device_grouped_contraction_multiple_d.hpp:25
std::vector< index_t > e_ms_ns_strides
Definition device_grouped_contraction_multiple_d.hpp:28
std::vector< index_t > e_ms_ns_lengths
Definition device_grouped_contraction_multiple_d.hpp:27
std::vector< index_t > b_ns_ks_strides
Definition device_grouped_contraction_multiple_d.hpp:22
Definition device_grouped_contraction_multiple_d.hpp:54
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< std::array< const void *, NumDTensor > > p_ds_vec, std::vector< void * > p_e_vec, std::vector< ContractionDesc< NumDTensor > > contraction_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumDTensor
Definition device_grouped_contraction_multiple_d.hpp:55