device_normalization_bwd_data.hpp Source File

device_normalization_bwd_data.hpp Source File

Composable Kernel: device_normalization_bwd_data.hpp Source File
device_normalization_bwd_data.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14template <typename DYDataType,
15 typename XDataType,
16 typename GammaDataType,
17 typename MeanInvStdDataType,
18 typename DXDataType,
19 index_t Rank,
20 index_t NumReduceDim>
22{
23 virtual std::unique_ptr<BaseArgument>
24 MakeArgumentPointer(const std::vector<index_t> lengths,
25 const std::vector<index_t> dyStrides,
26 const std::vector<index_t> xStrides,
27 const std::vector<index_t> gammaStrides,
28 const std::vector<index_t> meanStrides,
29 const std::vector<index_t> invStdStrides,
30 const std::vector<index_t> dxStrides,
31 const std::vector<index_t> reduceDims,
32 const void* p_dy,
33 const void* p_x,
34 const void* p_gamma,
35 const void* p_mean,
36 const void* p_invStd,
37 void* p_dx) = 0;
38
39 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
40};
41
42template <typename DYDataType,
43 typename XDataType,
44 typename GammaDataType,
45 typename MeanInvStdDataType,
46 typename DXDataType,
47 index_t Rank,
48 index_t NumReduceDim>
50 XDataType,
51 GammaDataType,
52 MeanInvStdDataType,
53 DXDataType,
54 Rank,
55 NumReduceDim>>;
56
57} // namespace device
58} // namespace tensor_operation
59} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceNormalizationBwdData< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim > > DeviceNormalizationBwdDataPtr
Definition device_normalization_bwd_data.hpp:49
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_normalization_bwd_data.hpp:22
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_gamma, const void *p_mean, const void *p_invStd, void *p_dx)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0