device_grouped_conv_bwd_data_multiple_d.hpp Source File

device_grouped_conv_bwd_data_multiple_d.hpp Source File

Composable Kernel: device_grouped_conv_bwd_data_multiple_d.hpp Source File
device_grouped_conv_bwd_data_multiple_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14// Conv backward data multiple D:
15// input : output image A[G, N, K, Ho, Wo]
16// input : weight B[G, K, C, Y, X]
17// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
18// output : input image E[G, N, C, Hi, Wi]
19// C = a_op(A) * b_op(B)
20// E = cde_op(C, D0, D1, ...)
// Abstract interface for a grouped convolution backward-data device operation
// that takes NumDTensor additional "D" tensors besides A, B and E. Both member
// functions declared here are pure virtual, so a concrete operation must
// override them.
//
// Template parameters:
//   NDimSpatial - number of spatial dimensions. Every length/stride array
//                 below has NDimSpatial + 3 entries; every per-dimension
//                 convolution parameter array has NDimSpatial entries.
//   ALayout, BLayout, DsLayout, ELayout - per-tensor layout types. DsLayout
//                 must expose a static Size() (see the static_assert below).
//   ADataType, BDataType, DsDataType, EDataType - per-tensor data types.
//                 DsDataType must expose a static Size(); it defines
//                 NumDTensor.
//   AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation -
//                 types of the element-wise operation objects passed to
//                 MakeArgumentPointer.
//   AComputeType, BComputeType - default to ADataType and AComputeType
//                 respectively; not referenced elsewhere in this listing.
21template <ck::index_t NDimSpatial,
22 typename ALayout,
23 typename BLayout,
24 typename DsLayout,
25 typename ELayout,
26 typename ADataType,
27 typename BDataType,
28 typename DsDataType,
29 typename EDataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CDEElementwiseOperation,
33 typename AComputeType = ADataType,
34 typename BComputeType = AComputeType>
// NOTE(review): listing line 35 is absent from this extract (the numbering
// jumps from 34 to 36). It should hold the struct/class declaration that the
// template header above and the braces below belong to -- presumably
// `DeviceGroupedConvBwdDataMultipleD`, going by the file name. Restore it from
// the original header; the name and any base class are not visible here.
36{
// Number of D tensors, taken from the size of the DsDataType type list.
37 static constexpr index_t NumDTensor = DsDataType::Size();
38
// Compile-time check: the layout list and the data-type list for the D tensors
// must describe the same number of tensors.
39 static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
40
// Builds the type-erased argument object for one problem instance.
//
// Pointers are untyped (`void*`): p_a, p_b and the entries of p_ds are
// read-only, p_e is writable. The inline comments label A as the "output
// image", B as the "weight", the D tensors as "bias" and E as the "input
// image". Each tensor is described by a lengths array and a strides array of
// NDimSpatial + 3 entries. The dimension order is presumably the one spelled
// out in the parameter names (e.g. g, n, k, spatial...) -- TODO confirm
// against an implementation.
//
// conv_filter_strides, conv_filter_dilations, input_left_pads and
// input_right_pads each hold one value per spatial dimension.
//
// NOTE(review): split_k has a default argument on a virtual function. C++
// selects default arguments from the static type of the call expression, so an
// override that declares a different default is only honoured when called
// through the derived type.
41 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
42 const void* p_a, // output image
43 const void* p_b, // weight
44 const std::array<const void*, NumDTensor>& p_ds, // bias
45 void* p_e, // input image
46 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
47 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
48 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
49 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
50 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
51 ds_g_n_k_wos_lengths, // bias
52 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
53 ds_g_n_k_wos_strides, // bias
54 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
55 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
56 const std::array<index_t, NDimSpatial>& conv_filter_strides,
57 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
58 const std::array<index_t, NDimSpatial>& input_left_pads,
59 const std::array<index_t, NDimSpatial>& input_right_pads,
60 const AElementwiseOperation& a_element_op,
61 const BElementwiseOperation& b_element_op,
62 const CDEElementwiseOperation& cde_element_op,
63 const ck::index_t split_k = 1) = 0;
64
// Builds the type-erased invoker object; takes no arguments.
65 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
66};
67
68} // namespace device
69} // namespace tensor_operation
70} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_grouped_conv_bwd_data_multiple_d.hpp:36
static constexpr index_t NumDTensor
Definition device_grouped_conv_bwd_data_multiple_d.hpp:37
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const ck::index_t split_k=1)=0