device_normalization_bwd_gamma_beta_impl.hpp Source File
device_normalization_bwd_gamma_beta_impl.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
__global__ void kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M dgamma_grid_desc_m, const GridDesc_M dbeta_grid_desc_m, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DGammaDataType *const __restrict__ p_dgamma_global, DBetaDataType *const __restrict__ p_dbeta_global)
Definition device_normalization_bwd_gamma_beta_impl.hpp:31
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_normalization_bwd_gamma_beta.hpp:37
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_normalization_bwd_gamma_beta.hpp:22
Definition device_normalization_bwd_gamma_beta_impl.hpp:221
GridDesc_M dgamma_grid_desc_m_
Definition device_normalization_bwd_gamma_beta_impl.hpp:297
std::vector< index_t > xStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:280
GridDesc_M_K mean_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:293
DBetaDataType * p_dbeta_
Definition device_normalization_bwd_gamma_beta_impl.hpp:276
std::vector< index_t > dyStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:279
size_t gridSize_
Definition device_normalization_bwd_gamma_beta_impl.hpp:288
std::vector< index_t > inLengths_
Definition device_normalization_bwd_gamma_beta_impl.hpp:278
const MeanInvStdDataType * p_invStd_
Definition device_normalization_bwd_gamma_beta_impl.hpp:274
GridDesc_M_K inv_std_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:294
std::vector< index_t > outLengths_
Definition device_normalization_bwd_gamma_beta_impl.hpp:283
const MeanInvStdDataType * p_mean_
Definition device_normalization_bwd_gamma_beta_impl.hpp:273
DGammaDataType * p_dgamma_
Definition device_normalization_bwd_gamma_beta_impl.hpp:275
int numBlockTileIteration_
Definition device_normalization_bwd_gamma_beta_impl.hpp:287
index_t MRaw_
Definition device_normalization_bwd_gamma_beta_impl.hpp:300
index_t KRaw_
Definition device_normalization_bwd_gamma_beta_impl.hpp:301
std::vector< index_t > dgammaStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:284
const XDataType * p_x_
Definition device_normalization_bwd_gamma_beta_impl.hpp:272
GridDesc_M dbeta_grid_desc_m_
Definition device_normalization_bwd_gamma_beta_impl.hpp:298
std::vector< index_t > dbetaStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:285
GridDesc_M_K x_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:292
GridDesc_M_K dy_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:291
Argument(const std::vector< index_t > inLengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > outLengths, const std::vector< index_t > dgammaStrides, const std::vector< index_t > dbetaStrides, const std::vector< index_t > reduceDims, const DYDataType *p_dy, const XDataType *p_x, const MeanInvStdDataType *p_mean, const MeanInvStdDataType *p_invStd, DGammaDataType *p_dgamma, DBetaDataType *p_dbeta)
Definition device_normalization_bwd_gamma_beta_impl.hpp:222
std::vector< index_t > invStdStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:282
std::vector< index_t > meanStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:281
const DYDataType * p_dy_
Definition device_normalization_bwd_gamma_beta_impl.hpp:271
Definition device_normalization_bwd_gamma_beta_impl.hpp:305
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_normalization_bwd_gamma_beta_impl.hpp:306
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_normalization_bwd_gamma_beta_impl.hpp:338
Definition device_normalization_bwd_gamma_beta_impl.hpp:89
bool IsSrcVectorDimSizeValid(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_bwd_gamma_beta_impl.hpp:346
static constexpr index_t K_BlockTileSize
Definition device_normalization_bwd_gamma_beta_impl.hpp:118
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_normalization_bwd_gamma_beta_impl.hpp:392
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int numBlockTileIteration)
Definition device_normalization_bwd_gamma_beta_impl.hpp:123
static constexpr index_t M_BlockTileSize
Definition device_normalization_bwd_gamma_beta_impl.hpp:117
decltype(MakeSrc2dDescriptor({1}, {1}, 1)) GridDesc_M_K
Definition device_normalization_bwd_gamma_beta_impl.hpp:194
bool IsDstVectorSizeValid(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_bwd_gamma_beta_impl.hpp:377
static auto MakeDst1dDescriptor(const std::vector< index_t > &outLengths, const std::vector< index_t > &outStrides)
Definition device_normalization_bwd_gamma_beta_impl.hpp:165
decltype(MakeDst1dDescriptor({1}, {1})) GridDesc_M
Definition device_normalization_bwd_gamma_beta_impl.hpp:195
static constexpr index_t MeanInvStdSrcVectorDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:92
static constexpr index_t DYSrcVectorDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:90
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_normalization_bwd_gamma_beta_impl.hpp:455
std::string GetTypeString() const override
Definition device_normalization_bwd_gamma_beta_impl.hpp:460
static constexpr bool reduceAllDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:120
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > inLengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > outLengths, const std::vector< index_t > dgammaStrides, const std::vector< index_t > dbetaStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_mean, const void *p_invStd, void *p_dgamma, void *p_dbeta) override
Definition device_normalization_bwd_gamma_beta_impl.hpp:414
static constexpr index_t XSrcVectorDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:91
ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl::GridwiseNormalizationBwdGammaBeta
GridwiseNormalizationBwdGammaBeta_mk_to_k< DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, GridDesc_M_K, GridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, DYSrcVectorDim, DYSrcVectorSize, XSrcVectorDim, XSrcVectorSize, MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize > GridwiseNormalizationBwdGammaBeta
Definition device_normalization_bwd_gamma_beta_impl.hpp:197
static constexpr index_t NumInvariantDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:116