Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant > Class Template Reference export

Device-templated fully connected (linear) component.


Public Types

using ComponentBase = Component<TDeviceType, TComputePrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using OpType = typename OperationTraits<OperationType::LinearOp, TDeviceType, TComputePrecision, TWeightQuant>::type
using TensorType = Tensor<TComputePrecision, MR>
using WeightScaleTensorType = Tensor<TWeightQuant::kScaleDtype, MR>
using WeightTensorType = Tensor<kWeightDtype, MR>

Public Member Functions

 Linear (const std::string &name, const LinearConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct a Linear component.
 ~Linear () override=default
TensorType & backward (const TensorType &input, const TensorType &output_grad)
 Perform backward pass.
TensorType & forward (const TensorType &input)
 Perform forward pass: output = input * weight^T + bias.
const LinearConfig & getConfig () const noexcept
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
bool hasBias () const noexcept
void loadParameter (const std::string &name, const ITensorBlob &blob) override
 Load a named parameter from a serialized blob.
size_t parameterCount () const override
 Return number of trainable parameters.
void save_ (ModelArchive &archive, SerializationMode mode) const override
 Save component state to a ModelArchive.
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
void zeroGradients () override
 Clear all model-owned gradients for this component.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
 Convenience accessor: true if currently in Eval mode.
bool isTrainingMode () const noexcept
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.

Static Public Attributes

static constexpr bool kIsQuantized = TWeightQuant::kIsQuantized
static constexpr TensorDataType kWeightDtype

Protected Member Functions

void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Lifecycle hook: Called immediately after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode mode) override
 Hook called before TrainingMode transitions.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
IExecutionContext getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Member Functions

void createOperation ()
 Instantiate the backend compute operation via compile-time traits dispatch.
void initializeGradients ()
void initializeParameters (const BuildContext &context)
void validateBuildContext (const BuildContext &context) const
void validateInputShape (const shape_t &input_shape) const

Private Attributes

std::shared_ptr< TensorType > bias_ { nullptr }
std::shared_ptr< TensorType > bias_grad_ { nullptr }
LinearConfig config_
std::unique_ptr< TensorType > input_grad_ { nullptr }
shape_t leading_shape_
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< TensorType > output_ { nullptr }
std::unique_ptr< TensorType > output_view_ { nullptr }
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }
std::shared_ptr< WeightTensorType > weight_ { nullptr }
std::shared_ptr< TensorType > weight_grad_ { nullptr }
std::unique_ptr< WeightScaleTensorType > weight_scales_ { nullptr }

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
BuildContext build_context_
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
requires PrecisionSupportedOnDevice<TComputePrecision, TDeviceType>
class Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >

Device-templated fully connected (linear) component.

Delegates compute to a device-specific operation resolved at compile time via OperationTraits<LinearOp, TDeviceType, TComputePrecision, TWeightQuant>. TWeightQuant defaults to NoWeightQuant for unquantized paths.

When TWeightQuant::kIsQuantized is true, the weight tensor is allocated at the reduced-precision storage dtype (kWeightDtype = TWeightQuant::kStorageDtype) rather than TComputePrecision. Per-channel FP32 scale factors (weight_scales_) are allocated alongside the weight tensor and bound to the backend operation via setWeightScales() before the first forward pass. The backend operation receives both the quantized weight tensor and its scales and is responsible for dequantization during the GEMM.

Weight quantization is performed once at model load time (quantize-on-load) during loadParameter(). The source checkpoint blob is always at TComputePrecision.

Template Parameters
TDeviceType        Target device.
TComputePrecision  Activation and accumulation precision.
TWeightQuant       Weight quantization policy. Must satisfy WeightQuantPolicy. Defaults to NoWeightQuant (identity: no quantization).

Constructor & Destructor Documentation

◆ Linear()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::Linear ( const std::string & name,
const LinearConfig & config,
std::optional< DeviceId > device_id = std::nullopt )
inline explicit export

Construct a Linear component.

Constructs with a name and configuration. If device_id is provided, the component creates and owns an ExecutionContext (standalone mode) and registers it with the base Component via setExecutionContext(). If device_id is not provided, the component expects a shared ExecutionContext to be provided later via setExecutionContext().

Parameters
name       Component name.
config     Layer configuration (validated on construction).
device_id  Optional device identifier. When present, the component creates an owned ExecutionContext for the device.
Exceptions
std::invalid_argument  if config is invalid or device type mismatches.
std::runtime_error     if ExecutionContext creation fails.
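
The two construction modes described above can be pictured with a small, self-contained sketch. Everything in it (`ExecutionContextSketch`, `ComponentSketch`, the integer device id, `ownsExecutionContext()`) is a hypothetical stand-in invented for this illustration and is not Mila's actual implementation; it only shows the ownership pattern of "create and own a context when a device id is given, otherwise wait for a shared one".

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for an execution context bound to one device.
struct ExecutionContextSketch {
    int device_index;
};

class ComponentSketch {
public:
    // With a device id: standalone mode, the component owns its context.
    // Without one: the context must be supplied later by the caller.
    explicit ComponentSketch(std::optional<int> device_id = std::nullopt) {
        if (device_id) {
            if (*device_id < 0)
                throw std::invalid_argument("invalid device id");
            owned_ = std::make_unique<ExecutionContextSketch>(
                ExecutionContextSketch{*device_id});
            setExecutionContext(owned_.get());
            assert(hasExecutionContext());
        }
    }

    // Registers a non-owning pointer; the caller keeps the context alive.
    void setExecutionContext(ExecutionContextSketch* context) {
        if (context == nullptr)
            throw std::invalid_argument("null execution context");
        context_ = context;
    }

    bool hasExecutionContext() const noexcept { return context_ != nullptr; }
    bool ownsExecutionContext() const noexcept { return owned_ != nullptr; }

    int deviceIndex() const {
        if (context_ == nullptr)
            throw std::runtime_error("execution context not set");
        return context_->device_index;
    }

private:
    std::unique_ptr<ExecutionContextSketch> owned_;   // set only in standalone mode
    ExecutionContextSketch* context_ = nullptr;       // owned or shared, never deleted here
};
```

In this sketch a component built with `ComponentSketch standalone(0);` is immediately usable, while `ComponentSketch shared;` reports `hasExecutionContext() == false` until a parent calls `setExecutionContext()` with a context it owns.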

◆ ~Linear()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::~Linear ( )
override export default

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
TensorType & Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::backward ( const TensorType & input,
const TensorType & output_grad )
inline export

Perform backward pass.

Pre-zeros the component-owned input gradient buffer, then delegates to the backend operation. The backend accumulates weight and bias gradients into the buffers bound via setGradients() using += semantics; pre-zeroing ensures clean gradient state across calls.

Not supported on quantized paths (kIsQuantized == true) — the backend operation will throw std::logic_error if backward is attempted.

Parameters
input        Original forward-pass input tensor.
output_grad  Upstream gradient tensor (same shape as the forward output).
Returns
Reference to the component-owned input gradient tensor.
Exceptions
std::runtime_error  if the component has not been built.
std::runtime_error  if called while in inference (eval) mode.
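
To make the gradient semantics concrete, here is a minimal CPU reference sketch of a linear backward pass. It is not Mila's backend operation: `LinearGradsSketch`, `linearBackwardSketch`, and the row-major layouts are assumptions chosen for the example. It mirrors the behaviour described above in that the input gradient is zeroed on every call, while the weight and bias gradients are accumulated with `+=`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical gradient buffers. Layouts are row-major:
// input [rows, in], weight [out, in], output_grad [rows, out].
struct LinearGradsSketch {
    std::vector<float> weight_grad; // [out, in], accumulated across calls
    std::vector<float> bias_grad;   // [out], accumulated across calls
    std::vector<float> input_grad;  // [rows, in], rewritten on every call
};

void linearBackwardSketch(const std::vector<float>& input,
                          const std::vector<float>& output_grad,
                          const std::vector<float>& weight,
                          std::size_t rows, std::size_t in, std::size_t out,
                          LinearGradsSketch& g)
{
    assert(input.size() == rows * in && output_grad.size() == rows * out);
    assert(weight.size() == out * in && g.weight_grad.size() == out * in);
    assert(g.bias_grad.size() == out && g.input_grad.size() == rows * in);

    // Pre-zero only the input gradient so repeated calls do not stack up.
    std::fill(g.input_grad.begin(), g.input_grad.end(), 0.0f);

    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t o = 0; o < out; ++o) {
            const float dy = output_grad[r * out + o];
            g.bias_grad[o] += dy;                                      // db += dY
            for (std::size_t i = 0; i < in; ++i) {
                g.weight_grad[o * in + i] += dy * input[r * in + i];   // dW += dY^T X
                g.input_grad[r * in + i]  += dy * weight[o * in + i];  // dX  = dY W
            }
        }
    }
}
```

Calling the function twice with the same arguments doubles `weight_grad` and `bias_grad` but leaves `input_grad` unchanged, which is why parameter gradients need an explicit `zeroGradients()` between optimizer steps.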

◆ createOperation()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::createOperation ( )
inline export private

Instantiate the backend compute operation via compile-time traits dispatch.

OpType is resolved by OperationTraits at instantiation time — no registry lookup, no string key, no runtime hash map. A missing specialization is a compile error.
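
The general technique of compile-time traits dispatch can be sketched as follows. The types below are simplified, hypothetical stand-ins (two template parameters instead of four, and invented `CpuLinearOp` / `CudaLinearOp` placeholders); Mila's real `OperationTraits` specializations are not shown on this page.

```cpp
#include <string_view>
#include <type_traits>

// Hypothetical enums and operation types for illustration only.
enum class DeviceType { Cpu, Cuda };
enum class OperationType { LinearOp };

struct CpuLinearOp {
    static constexpr std::string_view name() { return "CpuLinearOp"; }
};
struct CudaLinearOp {
    static constexpr std::string_view name() { return "CudaLinearOp"; }
};

// The primary template is declared but never defined, so asking for an
// unsupported (operation, device) pair fails at compile time.
template <OperationType Op, DeviceType Dev>
struct OperationTraits;

template <>
struct OperationTraits<OperationType::LinearOp, DeviceType::Cpu> {
    using type = CpuLinearOp;
};

template <>
struct OperationTraits<OperationType::LinearOp, DeviceType::Cuda> {
    using type = CudaLinearOp;
};

// The component picks its backend purely from its template arguments.
template <DeviceType Dev>
struct LinearSketch {
    using OpType = typename OperationTraits<OperationType::LinearOp, Dev>::type;
    OpType operation{};
};
```

Because the mapping lives in the type system, `LinearSketch<DeviceType::Cuda>::OpType` is known to the compiler without any runtime lookup, and adding a new backend means adding one more specialization.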

◆ forward()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
TensorType & Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::forward ( const TensorType & input)
inline export

Perform forward pass: output = input * weight^T + bias.

Delegates to the backend operation using the component-owned output buffer allocated at build time. When the runtime input shape differs from the build-time shape (e.g. a shorter decode sequence vs. the prefill shape), a lightweight view over the output buffer is returned that reflects the true output shape without reallocating device memory.

Parameters
input  Input tensor (device-bound, rank >= 2). The last dimension must equal the configured input feature count.
Returns
Reference to the output tensor or a shape-adjusted view of it.
Exceptions
std::runtime_error     if the component has not been built.
std::invalid_argument  if the input feature dimension does not match the config.
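
As a rough illustration of the formula and of the shape-adjusted view, the following CPU sketch computes output = input * weight^T + bias into a preallocated buffer and reports how many leading elements are valid. `linearForwardSketch`, its parameters, and the row-major layout are assumptions made for this example; in Mila the computation is delegated to the backend operation.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical CPU reference. Row-major layouts:
//   input  [rows, in_features]
//   weight [out_features, in_features]
//   bias   [out_features], or empty when the layer has no bias
// `out` is a buffer sized for the largest expected input. Only the first
// rows * out_features elements are written; that count is returned so the
// caller can treat the prefix as a "view" without reallocating.
std::size_t linearForwardSketch(const std::vector<float>& input, std::size_t rows,
                                const std::vector<float>& weight,
                                const std::vector<float>& bias,
                                std::size_t in_features, std::size_t out_features,
                                std::vector<float>& out)
{
    if (input.size() != rows * in_features)
        throw std::invalid_argument("input feature dimension mismatch");
    assert(weight.size() == out_features * in_features);
    assert(bias.empty() || bias.size() == out_features);
    assert(out.size() >= rows * out_features);

    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t o = 0; o < out_features; ++o) {
            float acc = bias.empty() ? 0.0f : bias[o];
            for (std::size_t i = 0; i < in_features; ++i)
                acc += input[r * in_features + i] * weight[o * in_features + i];
            out[r * out_features + o] = acc;
        }
    }
    return rows * out_features;
}
```

With a buffer allocated for three rows and an input of one row, the function writes only the first `out_features` elements and leaves the rest untouched, which is the same idea as returning a shorter view over a prefill-sized buffer during decode.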

◆ getConfig()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
const LinearConfig & Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::getConfig ( ) const
inline export noexcept

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
DeviceId Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::getDeviceId ( ) const
inline override export virtual

Get the compute device id associated with this component.

Must return the device on which parameters and operations execute.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getGradients()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
std::vector< ITensor * > Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::getGradients ( ) const
inline override export virtual

Return non-owning pointers to parameter gradient tensors.

Only valid when isTrainingMode() is true.

Exceptions
std::runtime_error  if called when not in training mode or before the component has been built.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
MemoryStats Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::getMemoryStats ( ) const
inline override export virtual

Return the current memory allocation breakdown for this component.

Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:

  • After construction: parameters only
  • After build( Inference ): parameters + T=1 state buffers
  • After build( Training ): parameters + T=full state buffers
  • After setEvaluation( false ): parameters + state + gradients

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.

May be called at any time — no lifecycle preconditions.

Returns
MemoryStats reflecting current allocations.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getParameters()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
std::vector< ITensor * > Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::getParameters ( ) const
inline override export virtual

Return non-owning pointers to parameter tensors.

The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getType()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
const ComponentType Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::getType ( ) const
inline override export virtual

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ hasBias()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::hasBias ( ) const
inline export noexcept

◆ initializeGradients()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::initializeGradients ( )
inline export private

◆ initializeParameters()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::initializeParameters ( const BuildContext & context)
inline export private

◆ loadParameter()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::loadParameter ( const std::string & name,
const ITensorBlob & blob )
inline override export virtual

Load a named parameter from a serialized blob.

Weight loading dispatches at compile time on kIsQuantized:

  • Unquantized path (kIsQuantized == false): the blob is validated against TComputePrecision and copied directly into the weight tensor via loadParameterFromBlob.
  • Quantized path (kIsQuantized == true): the blob dtype must be TComputePrecision (the full-precision source type). The backend operation's quantize() method performs per-channel absmax scale computation, quantizes weights from TComputePrecision to kWeightDtype (e.g. BF16 → FP8_E4M3), and uploads both the quantized weights and FP32 scales to device. The weight_scales_ tensor was pre-allocated in initializeParameters() and its device pointer was already bound to the operation in onBuilding() via setWeightScales() — quantize() writes directly into that allocation.

Bias is always stored and loaded at TComputePrecision regardless of TWeightQuant.
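
The per-channel absmax scheme mentioned for the quantized path can be sketched on the CPU as below. This is not Mila's `quantize()` implementation: the function names are hypothetical, one scale per output row is an assumption about what "per-channel" means here, and signed 8-bit integers with a maximum of 127 are swapped in for FP8_E4M3 purely to keep the example free of FP8 bit encoding.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container: int8 stands in for the reduced-precision storage dtype.
struct QuantizedWeightSketch {
    std::vector<std::int8_t> values; // [out_features, in_features], row-major
    std::vector<float> scales;       // [out_features], one FP32 scale per row
};

QuantizedWeightSketch quantizePerChannelAbsmax(const std::vector<float>& weight,
                                               std::size_t out_features,
                                               std::size_t in_features)
{
    assert(weight.size() == out_features * in_features);
    constexpr float kQMax = 127.0f; // largest representable magnitude for int8

    QuantizedWeightSketch q;
    q.values.resize(weight.size());
    q.scales.resize(out_features);

    for (std::size_t o = 0; o < out_features; ++o) {
        // 1. Largest absolute value in this output channel.
        float absmax = 0.0f;
        for (std::size_t i = 0; i < in_features; ++i)
            absmax = std::max(absmax, std::fabs(weight[o * in_features + i]));

        // 2. Scale maps absmax onto kQMax; an all-zero row keeps scale 1.
        const float scale = absmax > 0.0f ? absmax / kQMax : 1.0f;
        q.scales[o] = scale;

        // 3. Quantize: divide by the scale, round, clamp to the storable range.
        for (std::size_t i = 0; i < in_features; ++i) {
            const float v = std::round(weight[o * in_features + i] / scale);
            q.values[o * in_features + i] =
                static_cast<std::int8_t>(std::clamp(v, -kQMax, kQMax));
        }
    }
    return q;
}

// Dequantization, as a backend would do inside the GEMM: value * row scale.
float dequantizeSketch(const QuantizedWeightSketch& q, std::size_t o,
                       std::size_t i, std::size_t in_features)
{
    return static_cast<float>(q.values[o * in_features + i]) * q.scales[o];
}
```

The key property is that each row's largest-magnitude weight maps to the largest storable value, so rows with very different ranges each use the full precision of the storage type.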

Parameters
name  Parameter name: "weight" or "bias".
blob  Serialized tensor blob from PretrainedModelReader.
Exceptions
std::invalid_argument  if the blob dtype does not match the expected source precision, or if the blob shape does not match the config.

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::onBuilding ( const BuildContext & config)
inline override export protected virtual

Hook invoked by build() to allocate component buffers.

Receives the stored BuildContext. Implementations must use config.allocationSeqLen() when sizing output buffers — this is the single call that makes Inference and Training allocate the correct buffer sizes automatically without per-component logic.

// Example: Linear component
shape_t out_shape =
{
    config.batchSize(),
    config.allocationSeqLen(), // 1 for Inference, T for Training
    config_.getOutputFeatures()
};
output_ = std::make_unique<TensorType>( device, out_shape,
    this->getName() + ".output" );

The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.

Note
Do not call build() or onBuilding() from within this hook.
Implementations should either succeed fully or leave no partial state, as a failed build() may be retried.
Parameters
config  Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension.

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ onExecutionContextSet()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::onExecutionContextSet ( )
inline override export protected virtual

Lifecycle hook: Called immediately after ExecutionContext is set.

Override this to perform initialization that requires a valid ExecutionContext. At the time this is called, getExecutionContext() is guaranteed to return a valid context.

Common uses:

  • Composite components: Create and configure child components.
  • Device resource allocation: Query device capabilities.

Default implementation does nothing.

Exceptions
Any exception thrown will cause setExecutionContext() to fail and restore the component to a "context not set" state.

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::onTrainingModeChanging ( TrainingMode mode)
inline override export protected virtual

Hook called before TrainingMode transitions.

Called by setTrainingMode() after validation and lock acquisition, before the internal state is updated. Derived classes override to respond to the transition — e.g. zeroing gradient buffers on transition to Eval, or re-enabling dropout on transition to Training.

The default implementation is a no-op.

Parameters
mode  The incoming TrainingMode.
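
The ordering guarantee (hook first, state update second) can be illustrated with a small template-method sketch. All names here (`TrainingModeSketch`, `ModeAwareComponent`, `LinearModeSketch`) are hypothetical, the default mode of Eval is an assumption, and the validation and locking that the real setTrainingMode() performs are omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

enum class TrainingModeSketch { Eval, Training };

// Hypothetical base class: the public setter drives the protected hook.
class ModeAwareComponent {
public:
    virtual ~ModeAwareComponent() = default;

    void setTrainingMode(TrainingModeSketch mode) {
        onTrainingModeChanging(mode); // hook still observes the old mode
        mode_ = mode;                 // state is updated only afterwards
        assert(getTrainingMode() == mode);
    }

    TrainingModeSketch getTrainingMode() const noexcept { return mode_; }

protected:
    // Default hook does nothing.
    virtual void onTrainingModeChanging(TrainingModeSketch /*incoming*/) {}

private:
    TrainingModeSketch mode_ = TrainingModeSketch::Eval; // assumed default
};

// Hypothetical leaf component that clears its gradients when entering Eval.
class LinearModeSketch : public ModeAwareComponent {
public:
    std::vector<float> weight_grad{1.0f, 2.0f};
    TrainingModeSketch mode_seen_in_hook = TrainingModeSketch::Eval;

protected:
    void onTrainingModeChanging(TrainingModeSketch incoming) override {
        mode_seen_in_hook = getTrainingMode(); // the outgoing mode
        if (incoming == TrainingModeSketch::Eval)
            std::fill(weight_grad.begin(), weight_grad.end(), 0.0f);
    }
};
```

Inside the hook, `getTrainingMode()` still returns the outgoing mode and the `incoming` argument carries the new one, so an override can react to the specific transition being made.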

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ parameterCount()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
size_t Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::parameterCount ( ) const
inline override export virtual

Return number of trainable parameters.

For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.
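
For a linear layer the count reduces to simple arithmetic. The helper below is a hypothetical illustration, assuming a weight of shape [out_features, in_features] (consistent with output = input * weight^T + bias) and an optional bias of shape [out_features]; it is not the library's implementation.

```cpp
#include <cstddef>

// Hypothetical helper: element count of the weight plus the optional bias.
constexpr std::size_t linearParameterCount(std::size_t in_features,
                                           std::size_t out_features,
                                           bool has_bias)
{
    return in_features * out_features + (has_bias ? out_features : 0);
}
```

For example, a layer with 768 input features, 3072 output features, and a bias would hold 768 * 3072 + 3072 = 2,362,368 parameters under these assumptions.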

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ save_()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
inline override export virtual

Save component state to a ModelArchive.

Writes a "meta.json" blob with component type and name, a "config.json" blob with input/output feature dimensions and bias flag, and raw tensor blobs for the weight and (if present) bias parameters under "tensors/".

On CUDA devices, each tensor is copied to a temporary host buffer before writing. Weight is serialized at its storage dtype (kWeightDtype), which equals TWeightQuant::kStorageDtype on the quantized path.

Parameters
archive  ModelArchive to write to (scoped by caller).
mode     Serialization mode (currently unused; reserved for future use).

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::synchronize ( )
inline override export virtual

Wait for outstanding device work submitted by this component.

On CPU this may be a no-op. Use to ensure results are visible to the host or to measure synchronous timings.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ toString()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
std::string Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::toString ( ) const
inline override export virtual

Produce a short, human-readable description of the component.

Implementations should keep output concise and avoid throwing.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ validateBuildContext()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::validateBuildContext ( const BuildContext & context) const
inline export private

◆ validateInputShape()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::validateInputShape ( const shape_t & input_shape) const
inline export private

◆ zeroGradients()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >::zeroGradients ( )
inline override export virtual

Clear all model-owned gradients for this component.

Default implementation is a no-op. Composite components should override to recurse to children. Leaf components should override to zero their parameter and activation gradients using device-aware helpers.

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.


The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Components/Linear/Linear.ixx