Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision > Class Template Reference [export]

Public Types

using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using MR = CudaDeviceMemoryResource
using NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type
using TensorType = Tensor<TPrecision, MR>
using UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>
Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
using MR
using TensorInputType
using TensorOutputType
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
using DataTypeTraits

Public Member Functions

 CudaSwigluOp (IExecutionContext *context, const SwigluConfig &config)
void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override
 Backward pass: compute the gradient with respect to the input, given the output gradient.
void forward (const ITensor &input, ITensor &output) const override
 Forward pass: compute output = f(input).
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
virtual ~UnaryOperation ()=default
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
virtual ~Operation ()=default
virtual void build (const BuildContext &build_context)
 Prepare the operation for a concrete input shape.
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation (inference) mode.
virtual void setGradients (ITensor *weight_grad, ITensor *bias_grad)
 Bind module-owned gradient tensors to the operation.
virtual void setParameters (ITensor *weight, ITensor *bias)
 Bind module-owned parameter tensors to the operation.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Private Attributes

SwigluConfig config_
CudaExecutionContext * context_
Detail::cuda_swiglu_impl< NativeType > impl_

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
static const TensorInputType & asInputTensor (const ITensor &t)
static TensorOutputType & asOutputTensor (ITensor &t)
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
bool is_built_
TrainingMode training_mode_

Member Typedef Documentation

◆ CudaExecutionContext

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>

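◆ MR

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::MR = CudaDeviceMemoryResource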

◆ NativeType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type
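NativeType resolves the device-side element type for TPrecision through the TensorDataTypeMap trait. The specializations are not listed on this page, so the sketch below only illustrates the general trait-map pattern, using a hypothetical, reduced enum (FP32, INT32, UINT8) as a stand-in for Mila's TensorDataType.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in: Mila's real TensorDataType enumerators are not listed here.
enum class TensorDataType { FP32, INT32, UINT8 };

// Primary template is declared but never defined, so an unmapped
// precision is rejected at compile time.
template <TensorDataType T>
struct TensorDataTypeMap;

template <> struct TensorDataTypeMap<TensorDataType::FP32>  { using device_type = float; };
template <> struct TensorDataTypeMap<TensorDataType::INT32> { using device_type = std::int32_t; };
template <> struct TensorDataTypeMap<TensorDataType::UINT8> { using device_type = std::uint8_t; };

// Same shape as CudaSwigluOp::NativeType: the element type that kernels operate on.
template <TensorDataType TPrecision>
using NativeType = typename TensorDataTypeMap<TPrecision>::device_type;

static_assert(std::is_same_v<NativeType<TensorDataType::FP32>, float>);
static_assert(std::is_same_v<NativeType<TensorDataType::UINT8>, std::uint8_t>);
```

In this sketch, NativeType<TensorDataType::FP32> is float. On CUDA, reduced-precision formats would typically map to types such as __half, but the actual mapping is whatever Mila's TensorDataTypeMap defines; impl_ is then instantiated on that type as Detail::cuda_swiglu_impl< NativeType >.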

◆ TensorType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ UnaryOperationBase

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>

Constructor & Destructor Documentation

◆ CudaSwigluOp()

template<TensorDataType TPrecision>
Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::CudaSwigluOp ( IExecutionContext * context,
const SwigluConfig & config )
inline
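The constructor receives a polymorphic IExecutionContext*, while the class stores a typed CudaExecutionContext* (context_) next to a by-value copy of the SwigluConfig (config_). How Mila narrows the context is not documented here; the sketch below assumes a checked dynamic_cast that rejects null or non-CUDA contexts, and uses hypothetical stand-in types (IExecutionContextStub, CudaContextStub, SwigluConfigStub, SwigluOpSketch). Because context_ is a raw pointer, the sketch also assumes it is non-owning.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-ins for Mila's context hierarchy and config type.
struct IExecutionContextStub {
    virtual ~IExecutionContextStub() = default;
};
struct CpuContextStub : IExecutionContextStub {};
struct CudaContextStub : IExecutionContextStub {};
struct SwigluConfigStub {};

class SwigluOpSketch {
public:
    SwigluOpSketch(IExecutionContextStub* context, const SwigluConfigStub& config)
        : config_(config),
          context_(dynamic_cast<CudaContextStub*>(context)) {
        // dynamic_cast yields nullptr for a null or non-CUDA context;
        // fail at construction instead of at the first forward() call.
        if (context_ == nullptr) {
            throw std::invalid_argument("SwigluOpSketch: a CUDA execution context is required");
        }
    }

    const CudaContextStub* context() const noexcept { return context_; }

private:
    SwigluConfigStub config_;   // copied, mirroring config_ stored by value
    CudaContextStub* context_;  // assumed non-owning, mirroring the raw-pointer context_
};
```

Validating in the constructor keeps forward() and backward() free of per-call context checks; a static_cast would avoid the RTTI cost but would give up the check.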

Member Function Documentation

◆ backward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::backward ( const ITensor & input,
const ITensor & output_grad,
ITensor & input_grad ) const
inline override virtual

Backward pass: compute the gradient with respect to the input, given the output gradient.

The parameters are ordered (input, output_grad, input_grad) to match the convention used by module and operation implementations across the codebase.

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.
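This page does not state the SwiGLU math. Assuming the common fused formulation, in which each input row of width 2H is split into a gate g and a value v and the output is y = SiLU(g) * v (element-wise), the product rule gives dL/dv = dy * SiLU(g) and dL/dg = dy * v * s * (1 + g * (1 - s)), where s = sigmoid(g). The CPU float reference below mirrors the (input, output_grad, input_grad) argument order; the row-major layout with the gate in the first half and the name swiglu_backward_ref are assumptions, not taken from cuda_swiglu_impl.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

inline float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Reference SwiGLU backward. ASSUMED layout: each input row has width 2*hidden,
// gate in columns [0, hidden), value in columns [hidden, 2*hidden).
void swiglu_backward_ref(const std::vector<float>& input,
                         const std::vector<float>& output_grad,
                         std::vector<float>& input_grad,
                         std::size_t rows, std::size_t hidden) {
    assert(input.size() == rows * 2 * hidden);
    assert(output_grad.size() == rows * hidden);
    input_grad.assign(input.size(), 0.0f);
    for (std::size_t r = 0; r < rows; ++r) {
        const float* in = input.data() + r * 2 * hidden;
        const float* dy = output_grad.data() + r * hidden;
        float* dx = input_grad.data() + r * 2 * hidden;
        for (std::size_t j = 0; j < hidden; ++j) {
            const float g = in[j];
            const float v = in[hidden + j];
            const float s = sigmoid(g);
            // d/dg [g * s] = s * (1 + g * (1 - s))
            dx[j] = dy[j] * v * s * (1.0f + g * (1.0f - s));
            // d/dv [silu(g) * v] = silu(g)
            dx[hidden + j] = dy[j] * g * s;
        }
    }
}
```

A central finite difference on a small input is a quick way to check such a reference before comparing it against the CUDA kernel.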

◆ forward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::forward ( const ITensor & input,
ITensor & output ) const
inline override virtual

Forward pass: compute output = f(input).

Implementations should accept polymorphic ITensor references and may use the typed aliases and helpers (for example, asInputTensor and asOutputTensor) to obtain typed tensor references.

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.
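Under the same assumptions as a fused SwiGLU (gate in columns [0, H), value in [H, 2H), SiLU(x) = x * sigmoid(x)), the forward pass halves the last dimension: an input of shape [rows, 2H] produces an output of shape [rows, H]. A CPU float reference, useful for checking a kernel's output on small inputs, might look like this; swiglu_forward_ref and the layout are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU (Swish with beta = 1): silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Reference SwiGLU forward. ASSUMED layout: each input row has width 2*hidden,
// gate in columns [0, hidden), value in columns [hidden, 2*hidden).
std::vector<float> swiglu_forward_ref(const std::vector<float>& input,
                                      std::size_t rows, std::size_t hidden) {
    assert(input.size() == rows * 2 * hidden);
    std::vector<float> output(rows * hidden);
    for (std::size_t r = 0; r < rows; ++r) {
        const float* in = input.data() + r * 2 * hidden;
        for (std::size_t j = 0; j < hidden; ++j) {
            // output = silu(gate) * value, element-wise
            output[r * hidden + j] = silu(in[j]) * in[hidden + j];
        }
    }
    return output;
}
```

Each output element depends on a single gate/value pair, so a CUDA kernel can map one thread to one output element; whether cuda_swiglu_impl does so is not stated on this page.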

◆ getName()

template<TensorDataType TPrecision>
std::string Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::getName ( ) const
inline override virtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ getOperationType()

template<TensorDataType TPrecision>
OperationType Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::getOperationType ( ) const
inline override virtual

Operation type identifier.

Member Data Documentation

◆ config_

template<TensorDataType TPrecision>
SwigluConfig Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::config_
private

◆ context_

template<TensorDataType TPrecision>
CudaExecutionContext* Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::context_
private

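◆ impl_

template<TensorDataType TPrecision>
Detail::cuda_swiglu_impl< NativeType > Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >::impl_
private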

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Activations/Swiglu/CudaSwigluOp.ixx