Mila
Deep Neural Network Library
Mila::Dnn::MultiHeadAttentionConfig Class Reference

Configuration class for MultiHeadAttention module.


Public Member Functions

 MultiHeadAttentionConfig (size_t embedding_dim, size_t num_heads)
 Constructor with required parameters.
 
float getDropout () const
 Get the dropout rate.
 
size_t getEmbeddingDim () const
 Get the embedding dimension.
 
const std::vector< size_t > & getInputShape () const
 Get the input shape.
 
size_t getNumHeads () const
 Get the number of attention heads.
 
float getScaleFactor () const
 Get the attention scaling factor.
 
bool useCausalMask () const
 Check if causal masking is enabled.
 
bool useSeparateProjections () const
 Check if using separate projection matrices.
 
void validate () const
 Validate configuration parameters.
 
MultiHeadAttentionConfig & withCausalMask (bool causal)
 Configure whether to use causal attention mask.
 
MultiHeadAttentionConfig & withDropout (float dropout)
 Set the dropout rate for attention weights.
 
MultiHeadAttentionConfig & withInputShape (const std::vector< size_t > &input_shape)
 Set the input shape for the attention module.
 
MultiHeadAttentionConfig & withScaleFactor (float scale_factor)
 Set the scaling factor for attention logits.
 
MultiHeadAttentionConfig & withSeparateProjections (bool separate_projections)
 Configure whether to use separate projection matrices for query, key, and value.
 
- Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor to support proper polymorphic destruction.
 
const std::string & getName () const
 Gets the configured component name.
 
ComputePrecision::Policy getPrecision () const
 Gets the configured precision policy.
 
bool isTraining () const
 Gets the configured training mode.
 
template<typename Self >
auto & withName (this Self &&self, std::string name)
 Sets the name of the component with fluent interface.
 
template<typename Self >
auto & withPrecision (this Self &&self, ComputePrecision::Policy policy)
 Sets the compute precision policy with fluent interface.
 
template<typename Self >
auto & withTraining (this Self &&self, bool is_training)
 Sets the training mode with fluent interface.
 

Private Attributes

float dropout_ = 0.0f
 
size_t embedding_dim_
 
std::vector< size_t > input_shape_
 
size_t num_heads_
 
float scale_factor_ = 1.0f
 
bool separate_projections_ = true
 
bool use_causal_mask_ = false
 

Additional Inherited Members

- Protected Attributes inherited from Mila::Dnn::ComponentConfig
bool is_training_ = false
 Training mode flag, defaults to false (inference mode)
 
std::string name_ = "unnamed"
 Component name, defaults to "unnamed" if not explicitly set.
 
ComputePrecision::Policy precision_ = ComputePrecision::Policy::Auto
 Precision policy for computation, defaults to Auto.
 

Detailed Description

Configuration class for the MultiHeadAttention module. Collects the attention hyperparameters (embedding dimension, number of heads, dropout, scale factor, causal masking, and projection layout) and exposes a fluent interface, inherited from ComponentConfig, for setting the optional ones.

Constructor & Destructor Documentation

◆ MultiHeadAttentionConfig()

Mila::Dnn::MultiHeadAttentionConfig::MultiHeadAttentionConfig ( size_t embedding_dim, size_t num_heads )
inline

Constructor with required parameters.

Parameters
embedding_dim: The embedding dimension size
num_heads: The number of attention heads

Member Function Documentation

◆ getDropout()

float Mila::Dnn::MultiHeadAttentionConfig::getDropout ( ) const
inline

Get the dropout rate.


◆ getEmbeddingDim()

size_t Mila::Dnn::MultiHeadAttentionConfig::getEmbeddingDim ( ) const
inline

Get the embedding dimension.


◆ getInputShape()

const std::vector< size_t > & Mila::Dnn::MultiHeadAttentionConfig::getInputShape ( ) const
inline

Get the input shape.


◆ getNumHeads()

size_t Mila::Dnn::MultiHeadAttentionConfig::getNumHeads ( ) const
inline

Get the number of attention heads.


◆ getScaleFactor()

float Mila::Dnn::MultiHeadAttentionConfig::getScaleFactor ( ) const
inline

Get the attention scaling factor.


◆ useCausalMask()

bool Mila::Dnn::MultiHeadAttentionConfig::useCausalMask ( ) const
inline

Check if causal masking is enabled.


◆ useSeparateProjections()

bool Mila::Dnn::MultiHeadAttentionConfig::useSeparateProjections ( ) const
inline

Check if using separate projection matrices.


◆ validate()

void Mila::Dnn::MultiHeadAttentionConfig::validate ( ) const
inlinevirtual

Validate configuration parameters.

Exceptions
std::invalid_argument: If validation fails

Reimplemented from Mila::Dnn::ComponentConfig.


◆ withCausalMask()

MultiHeadAttentionConfig & Mila::Dnn::MultiHeadAttentionConfig::withCausalMask ( bool  causal)
inline

Configure whether to use causal attention mask.

Parameters
causal: True to use causal masking (for decoder/autoregressive models)
Returns
MultiHeadAttentionConfig& Reference to this for method chaining

◆ withDropout()

MultiHeadAttentionConfig & Mila::Dnn::MultiHeadAttentionConfig::withDropout ( float  dropout)
inline

Set the dropout rate for attention weights.

Parameters
dropout: Dropout probability (0.0 to 1.0)
Returns
MultiHeadAttentionConfig& Reference to this for method chaining

◆ withInputShape()

MultiHeadAttentionConfig & Mila::Dnn::MultiHeadAttentionConfig::withInputShape ( const std::vector< size_t > &  input_shape)
inline

Set the input shape for the attention module.

Parameters
input_shape: Vector containing input tensor dimensions [batch_size, seq_len, embedding_dim]
Returns
MultiHeadAttentionConfig& Reference to this for method chaining

◆ withScaleFactor()

MultiHeadAttentionConfig & Mila::Dnn::MultiHeadAttentionConfig::withScaleFactor ( float  scale_factor)
inline

Set the scaling factor for attention logits.

Parameters
scale_factor: Scaling factor (typically 1/sqrt(head_dim))
Returns
MultiHeadAttentionConfig& Reference to this for method chaining

◆ withSeparateProjections()

MultiHeadAttentionConfig & Mila::Dnn::MultiHeadAttentionConfig::withSeparateProjections ( bool  separate_projections)
inline

Configure whether to use separate projection matrices for query, key, and value.

Parameters
separate_projections: True to use separate projections, false to use a single matrix
Returns
MultiHeadAttentionConfig& Reference to this for method chaining

Member Data Documentation

◆ dropout_

float Mila::Dnn::MultiHeadAttentionConfig::dropout_ = 0.0f
private

◆ embedding_dim_

size_t Mila::Dnn::MultiHeadAttentionConfig::embedding_dim_
private

◆ input_shape_

std::vector<size_t> Mila::Dnn::MultiHeadAttentionConfig::input_shape_
private

◆ num_heads_

size_t Mila::Dnn::MultiHeadAttentionConfig::num_heads_
private

◆ scale_factor_

float Mila::Dnn::MultiHeadAttentionConfig::scale_factor_ = 1.0f
private

◆ separate_projections_

bool Mila::Dnn::MultiHeadAttentionConfig::separate_projections_ = true
private

◆ use_causal_mask_

bool Mila::Dnn::MultiHeadAttentionConfig::use_causal_mask_ = false
private
