Mila 0.13.48
Deep Neural Network Library
Tensor.Partitioning.ixx File Reference
#include <stdexcept>
#include <string>
#include <cstdint>
import Dnn.TensorTypes;

Classes

struct  Mila::Dnn::AxisPartition
 Information about axis partitioning of a tensor.
struct  Mila::Dnn::MultiAxisPartition
 Multi-axis partition for normalization over trailing dimensions.

Namespaces

namespace  Mila
 Mila main API namespace.
namespace  Mila::Dnn

Functions

AxisPartition Mila::Dnn::computeAxisPartition (const shape_t &shape, dim_t axis, const char *op_name="Operation")
 Normalize and validate an axis, then compute partition sizes.
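The following is a minimal sketch of what `computeAxisPartition` might do. It is not Mila's implementation. It assumes that `shape_t` is a vector of signed 64-bit extents, that `dim_t` is a signed index, and that negative axes count from the end. It also assumes the partition is the usual outer/axis/inner split used by per-axis operations such as softmax. The `AxisPartition` fields shown are hypothetical, and the code lives in a `sketch` namespace so it is not mistaken for the library API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

// Assumption: shape_t / dim_t mirror Dnn.TensorTypes as signed 64-bit types.
using shape_t = std::vector<int64_t>;
using dim_t = int64_t;

// Hypothetical layout; the real Mila::Dnn::AxisPartition fields may differ.
struct AxisPartition {
    dim_t normalized_axis; // axis after mapping negative values into [0, rank)
    int64_t outer_size;    // product of extents before the axis
    int64_t axis_size;     // extent of the axis itself
    int64_t inner_size;    // product of extents after the axis
};

AxisPartition computeAxisPartition(const shape_t& shape, dim_t axis,
                                   const char* op_name = "Operation") {
    const auto rank = static_cast<dim_t>(shape.size());
    // Assumed convention: -1 refers to the last axis, -rank to the first.
    const dim_t normalized = axis < 0 ? axis + rank : axis;
    if (normalized < 0 || normalized >= rank) {
        throw std::out_of_range(std::string(op_name) + ": axis " +
                                std::to_string(axis) +
                                " is out of range for rank " +
                                std::to_string(rank));
    }

    AxisPartition p{normalized, 1, shape[static_cast<std::size_t>(normalized)], 1};
    for (dim_t i = 0; i < normalized; ++i)
        p.outer_size *= shape[static_cast<std::size_t>(i)];
    for (dim_t i = normalized + 1; i < rank; ++i)
        p.inner_size *= shape[static_cast<std::size_t>(i)];
    return p;
}

} // namespace sketch
```

Suppose the shape is `{2, 3, 4}` and the axis is `1`. The sketch then yields outer 2, axis 3, and inner 4. Assume the storage is contiguous and row-major. In that case, element `(o, a, i)` sits at flat offset `(o * axis_size + a) * inner_size + i`, which lets a kernel walk the axis with a fixed stride of `inner_size`.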
MultiAxisPartition Mila::Dnn::computeNormalizedShapePartition (const shape_t &shape, const shape_t &normalized_shape, const char *op_name="Operation")
 Compute partition for normalization over trailing dimensions.
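A hedged sketch of a trailing-dimension partition in the style of layer normalization follows. It assumes three things: `normalized_shape` must equal the last `k` extents of `shape`, the leading extents collapse into one outer size, and the trailing extents collapse into one normalized size. The `MultiAxisPartition` fields are hypothetical. Rejecting an empty `normalized_shape` is also an assumption, since the reference above does not state the real validation rules.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

// Assumption: shape_t is a vector of signed 64-bit extents.
using shape_t = std::vector<int64_t>;

// Hypothetical layout; the real Mila::Dnn::MultiAxisPartition may differ.
struct MultiAxisPartition {
    int64_t outer_size;      // product of leading (non-normalized) extents
    int64_t normalized_size; // product of trailing extents being normalized
    std::size_t num_axes;    // number of trailing axes covered
};

MultiAxisPartition computeNormalizedShapePartition(
        const shape_t& shape, const shape_t& normalized_shape,
        const char* op_name = "Operation") {
    const std::size_t rank = shape.size();
    const std::size_t k = normalized_shape.size();
    if (k == 0 || k > rank) {
        throw std::invalid_argument(std::string(op_name) +
            ": normalized_shape has " + std::to_string(k) +
            " dims but input rank is " + std::to_string(rank));
    }

    // normalized_shape must match the trailing k dimensions exactly.
    const std::size_t offset = rank - k;
    for (std::size_t i = 0; i < k; ++i) {
        if (shape[offset + i] != normalized_shape[i]) {
            throw std::invalid_argument(std::string(op_name) +
                ": normalized_shape does not match input at trailing position " +
                std::to_string(i));
        }
    }

    MultiAxisPartition p{1, 1, k};
    for (std::size_t i = 0; i < offset; ++i) p.outer_size *= shape[i];
    for (std::size_t i = offset; i < rank; ++i) p.normalized_size *= shape[i];
    return p;
}

} // namespace sketch
```

Take an input of shape `{4, 8, 16}` with `normalized_shape` `{8, 16}`. The sketch gives 4 independent groups of 128 elements, each of which would be normalized on its own.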
int64_t Mila::Dnn::computeNumElements (const shape_t &shape)
 Compute total number of elements in a tensor shape.
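A minimal sketch of the element count as the product of all extents. Two behaviors here are assumptions and may not match Mila: an empty shape is treated as a scalar with one element, and negative extents are rejected. `shape_t` is likewise assumed to be a vector of signed 64-bit integers. Overflow of the product is not checked.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace sketch {

// Assumption: shape_t is a vector of signed 64-bit extents.
using shape_t = std::vector<int64_t>;

// Product of all extents; an empty shape yields 1 (assumed scalar semantics).
int64_t computeNumElements(const shape_t& shape) {
    int64_t n = 1;
    for (int64_t d : shape) {
        if (d < 0) throw std::invalid_argument("negative dimension in shape");
        n *= d; // no overflow check in this sketch
    }
    return n;
}

} // namespace sketch
```

For example, `{2, 3, 4}` gives 24. Any zero extent makes the whole tensor empty.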
void Mila::Dnn::validateTensorSize (const shape_t &shape, int64_t expected_size, const char *tensor_name="tensor", const char *op_name="Operation")
 Validate that a tensor has the expected number of elements.
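A hedged sketch of the size check. It compares the element count of `shape` against `expected_size` and throws with a message naming both the operation and the tensor. Mila's actual exception type and message are not documented here; `std::invalid_argument` is an assumption. The helper count is repeated inline so the block stands alone.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

// Assumption: shape_t is a vector of signed 64-bit extents.
using shape_t = std::vector<int64_t>;

// Element count as the product of extents (empty shape -> 1).
int64_t computeNumElements(const shape_t& shape) {
    int64_t n = 1;
    for (int64_t d : shape) n *= d;
    return n;
}

// Throws if the shape does not hold exactly expected_size elements.
void validateTensorSize(const shape_t& shape, int64_t expected_size,
                        const char* tensor_name = "tensor",
                        const char* op_name = "Operation") {
    const int64_t actual = computeNumElements(shape);
    if (actual != expected_size) {
        throw std::invalid_argument(std::string(op_name) + ": " + tensor_name +
                                    " has " + std::to_string(actual) +
                                    " elements, expected " +
                                    std::to_string(expected_size));
    }
}

} // namespace sketch
```

A typical call would be `validateTensorSize(bias_shape, out_features, "bias", "Linear")`. It guards a parameter tensor before a kernel reads it. The names `bias_shape` and `out_features` are hypothetical.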