Mila
Deep Neural Network Library
Mila::Dnn::MultiHeadAttentionConfig Class Reference

Configuration class for MultiHeadAttention module.
Public Member Functions

  MultiHeadAttentionConfig(size_t embedding_dim, size_t num_heads)
      Constructor with required parameters.

  float getDropout() const
      Get the dropout rate.

  size_t getEmbeddingDim() const
      Get the embedding dimension.

  const std::vector<size_t>& getInputShape() const
      Get the input shape.

  size_t getNumHeads() const
      Get the number of attention heads.

  float getScaleFactor() const
      Get the attention scaling factor.

  bool useCausalMask() const
      Check if causal masking is enabled.

  bool useSeparateProjections() const
      Check if using separate projection matrices.

  void validate() const
      Validate configuration parameters.

  MultiHeadAttentionConfig& withCausalMask(bool causal)
      Configure whether to use causal attention mask.

  MultiHeadAttentionConfig& withDropout(float dropout)
      Set the dropout rate for attention weights.

  MultiHeadAttentionConfig& withInputShape(const std::vector<size_t>& input_shape)
      Set the input shape for the attention module.

  MultiHeadAttentionConfig& withScaleFactor(float scale_factor)
      Set the scaling factor for attention logits.

  MultiHeadAttentionConfig& withSeparateProjections(bool separate_projections)
      Configure whether to use separate projection matrices for query, key, and value.
Public Member Functions inherited from Mila::Dnn::ComponentConfig

  virtual ~ComponentConfig() = default
      Virtual destructor to support proper polymorphic destruction.

  const std::string& getName() const
      Gets the configured component name.

  ComputePrecision::Policy getPrecision() const
      Gets the configured precision policy.

  bool isTraining() const
      Gets the configured training mode.

  template<typename Self>
  auto& withName(this Self&& self, std::string name)
      Sets the name of the component with fluent interface.

  template<typename Self>
  auto& withPrecision(this Self&& self, ComputePrecision::Policy policy)
      Sets the compute precision policy with fluent interface.

  template<typename Self>
  auto& withTraining(this Self&& self, bool is_training)
      Sets the training mode with fluent interface.
Private Attributes

  float dropout_ = 0.0f
  size_t embedding_dim_
  std::vector<size_t> input_shape_
  size_t num_heads_
  float scale_factor_ = 1.0f
  bool separate_projections_ = true
  bool use_causal_mask_ = false
Additional Inherited Members

Attributes inherited from Mila::Dnn::ComponentConfig

  bool is_training_ = false
      Training mode flag, defaults to false (inference mode).

  std::string name_ = "unnamed"
      Component name, defaults to "unnamed" if not explicitly set.

  ComputePrecision::Policy precision_ = ComputePrecision::Policy::Auto
      Precision policy for computation, defaults to Auto.
Detailed Description

Configuration class for MultiHeadAttention module.

Member Function Documentation

MultiHeadAttentionConfig(size_t embedding_dim, size_t num_heads)  [inline]

Constructor with required parameters.

Parameters
  embedding_dim  The embedding dimension size
  num_heads      The number of attention heads
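A minimal usage sketch based on the constructor and getters listed above. The header path is an assumption and may differ in the Mila source tree; everything else is taken from this reference.

    // Assumed header location; adjust to wherever Mila installs its headers.
    #include "Mila/Dnn/MultiHeadAttentionConfig.h"

    #include <cstddef>

    using namespace Mila::Dnn;

    int main() {
        // 768-dimensional embeddings split across 12 attention heads.
        MultiHeadAttentionConfig config( 768, 12 );

        // The getters reflect the constructor arguments and the documented defaults.
        const size_t dim    = config.getEmbeddingDim();   // 768
        const size_t heads  = config.getNumHeads();       // 12
        const float  drop   = config.getDropout();        // 0.0f (default)
        const bool   causal = config.useCausalMask();     // false (default)

        (void)dim; (void)heads; (void)drop; (void)causal;
        return 0;
    }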
float getDropout() const  [inline]

Get the dropout rate.

size_t getEmbeddingDim() const  [inline]

Get the embedding dimension.

const std::vector<size_t>& getInputShape() const  [inline]

Get the input shape.

size_t getNumHeads() const  [inline]

Get the number of attention heads.

float getScaleFactor() const  [inline]

Get the attention scaling factor.

bool useCausalMask() const  [inline]

Check if causal masking is enabled.

bool useSeparateProjections() const  [inline]

Check if using separate projection matrices.
void validate() const  [inline, virtual]

Validate configuration parameters.

Exceptions
  std::invalid_argument  If validation fails

Reimplemented from Mila::Dnn::ComponentConfig.
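A hedged sketch of using validate(). Only the std::invalid_argument contract comes from this reference; the helper function and the printed message are illustrative, and the Mila header include is assumed as in the first sketch.

    #include <iostream>
    #include <stdexcept>

    using namespace Mila::Dnn;

    // Returns true when the configuration passes validation, false otherwise.
    bool isValidAttentionConfig( const MultiHeadAttentionConfig& config ) {
        try {
            config.validate();   // throws std::invalid_argument on failure
            return true;
        }
        catch ( const std::invalid_argument& e ) {
            std::cerr << "Invalid MultiHeadAttention configuration: " << e.what() << '\n';
            return false;
        }
    }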
MultiHeadAttentionConfig& withCausalMask(bool causal)  [inline]

Configure whether to use causal attention mask.

Parameters
  causal  True to use causal masking (for decoder/autoregressive models)

MultiHeadAttentionConfig& withDropout(float dropout)  [inline]

Set the dropout rate for attention weights.

Parameters
  dropout  Dropout probability (0.0 to 1.0)
MultiHeadAttentionConfig& withInputShape(const std::vector<size_t>& input_shape)  [inline]

Set the input shape for the attention module.

Parameters
  input_shape  Vector containing input tensor dimensions [batch_size, seq_len, embedding_dim]
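A short sketch of the documented [batch_size, seq_len, embedding_dim] layout; the concrete sizes below are placeholders, and the Mila header include is assumed as in the first sketch.

    #include <cstddef>
    #include <vector>

    using namespace Mila::Dnn;

    MultiHeadAttentionConfig makeConfigWithShape() {
        MultiHeadAttentionConfig config( 512, 8 );

        // [batch_size, seq_len, embedding_dim]; the last dimension matches the
        // embedding_dim passed to the constructor.
        config.withInputShape( std::vector<size_t>{ 32, 128, 512 } );

        return config;
    }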
MultiHeadAttentionConfig& withScaleFactor(float scale_factor)  [inline]

Set the scaling factor for attention logits.

Parameters
  scale_factor  Scaling factor (typically 1/sqrt(head_dim))
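A worked sketch of the 1/sqrt(head_dim) convention noted above, assuming head_dim = embedding_dim / num_heads (the Mila header include is assumed as in the first sketch).

    #include <cmath>
    #include <cstddef>

    using namespace Mila::Dnn;

    MultiHeadAttentionConfig makeScaledConfig( size_t embedding_dim, size_t num_heads ) {
        MultiHeadAttentionConfig config( embedding_dim, num_heads );

        // Per-head width, e.g. 768 / 12 = 64, giving a scale factor of 1/8 = 0.125.
        const float head_dim = static_cast<float>( embedding_dim / num_heads );

        // Standard scaled dot-product attention scaling.
        config.withScaleFactor( 1.0f / std::sqrt( head_dim ) );

        return config;
    }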
MultiHeadAttentionConfig& withSeparateProjections(bool separate_projections)  [inline]

Configure whether to use separate projection matrices for query, key, and value.

Parameters
  separate_projections  True to use separate projections, false to use a single matrix
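Putting the fluent setters together, a sketch of a decoder-style (causal) configuration. The inherited withName/withTraining calls assume the deducing-this fluent interface listed above returns the derived configuration type; the chaining order is otherwise arbitrary, and the Mila header include is assumed as in the first sketch.

    #include <cstddef>
    #include <string>
    #include <vector>

    using namespace Mila::Dnn;

    MultiHeadAttentionConfig makeDecoderAttentionConfig() {
        auto config = MultiHeadAttentionConfig( 1024, 16 )
                          .withInputShape( { 8, 256, 1024 } )   // [batch_size, seq_len, embedding_dim]
                          .withCausalMask( true )               // autoregressive decoding
                          .withSeparateProjections( true )      // distinct Q/K/V projection matrices
                          .withDropout( 0.1f );

        // Inherited ComponentConfig setters (see the inherited member list above).
        config.withName( "decoder_self_attention" )
              .withTraining( true );

        config.validate();   // throws std::invalid_argument if the settings are inconsistent
        return config;
    }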