Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Components.Linear Module Reference

Exported Modules

module  Dnn.TensorTypes
module  Serialization.Tensor
module  Compute.Device
module  Compute.DeviceId
module  Dnn.Tensor
module  Dnn.Components.LinearConfig
module  Dnn.ComponentType
module  Logging.Logger
module  Dnn.Quantization.Weight.Policies
module  Dnn.TensorDataTypeTraits
module  Serialization.ModelArchive
module  Compute.ExecutionContextFactory
module  Dnn.Component
module  Compute.CpuMemoryResource
module  Compute.OperationType
module  Compute.DeviceType
module  Dnn.ITensor
module  Serialization.Mode
module  Compute.IExecutionContext
module  Dnn.TensorDataType
module  nlohmann.json
module  Compute.MemoryResource
module  Dnn.TensorHelpers
module  Compute.DeviceTypeTraits
module  Dnn.TensorOps
module  Dnn.TensorOps
module  Compute.OperationTraits

Classes

class  Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >
 Device-templated fully connected (linear) component. More...

Typedefs

using ComponentBase = Component<TDeviceType, TComputePrecision>
using Mila::Dnn::json = nlohmann::json
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using OpType = typename OperationTraits<OperationType::LinearOp, TDeviceType, TComputePrecision, TWeightQuant>::type
using TensorType = Tensor<TComputePrecision, MR>
using WeightScaleTensorType = Tensor<TWeightQuant::kScaleDtype, MR>
using WeightTensorType = Tensor<kWeightDtype, MR>

Functions

 Linear (const std::string &name, const LinearConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct a Linear component.
 ~Linear () override=default
TensorType backward (const TensorType &input, const TensorType &output_grad)
 Perform backward pass.
void createOperation ()
 Instantiate the backend compute operation via compile-time traits dispatch.
TensorType forward (const TensorType &input)
 Perform forward pass: output = input * weight^T + bias.
const LinearConfig getConfig () const noexcept
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
bool hasBias () const noexcept
void initializeGradients ()
void initializeParameters (const BuildContext &context)
void loadParameter (const std::string &name, const ITensorBlob &blob) override
 Load a named parameter from a serialized blob.
void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Lifecycle hook: Called immediately after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode mode) override
 Hook called before TrainingMode transitions.
size_t parameterCount () const override
 Return number of trainable parameters.
void save_ (ModelArchive &archive, SerializationMode mode) const override
 Save component state to a ModelArchive.
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
void validateBuildContext (const BuildContext &context) const
void validateInputShape (const shape_t &input_shape) const
void zeroGradients () override
 Clear all model-owned gradients for this component.

Variables

std::shared_ptr< TensorType > bias_ { nullptr }
std::shared_ptr< TensorType > bias_grad_ { nullptr }
LinearConfig config_
std::unique_ptr< TensorType > input_grad_ { nullptr }
static constexpr bool kIsQuantized = TWeightQuant::kIsQuantized
static constexpr TensorDataType kWeightDtype
shape_t leading_shape_
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< TensorType > output_ { nullptr }
std::unique_ptr< TensorType > output_view_ { nullptr }
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }
std::shared_ptr< WeightTensorType > weight_ { nullptr }
std::shared_ptr< TensorType > weight_grad_ { nullptr }
std::unique_ptr< WeightScaleTensorType > weight_scales_ { nullptr }

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Linear/Linear.ixx
 Device-templated Linear (fully connected) component.