High-Performance Matrix Multiplication for Machine Learning, Ray Tracing, and Beyond: The Five Procedures to Multiply Matrices
If you are reading this, you probably know how to multiply two matrices together, but if you haven't taken a Linear Algebra course, you might not realize that this is just one way to do it. Without a Numerical Analysis course, you might not know how useful the different methods are.
Understanding what is meant by 'five procedures', and how they differ, is useful for understanding and optimizing how large matrices, or large numbers of matrices, can be multiplied efficiently under different circumstances: in ray tracing or neural networks, for example, and on CPUs versus GPUs versus TPUs.
The five procedures all work on the same data, produce the same result, and perform the same internal calculations, but they differ in how they group matrix elements, how those groupings are interpreted, and how they organize memory access and execution order.
The Procedures
To see what I mean, let's look at the definition of $\Z=\X\Y$. It can be expanded, or written out as:
$$z_{j,i} = \sum_{k} x_{j,k} y_{k,i}$$
which reads as "The ji-th cell of Z is the sum over all k of the product of the jk-th cell of X with the ki-th cell of Y." While this gives us a nice mathematical definition, it doesn't say anything about how to implement the loops over i, j, or k, nor does it say anything about the memory access pattern of an implementation.
We have implementation choices to make, and we'll group these choices into five procedures categorized by how we arrange the loops. These five procedures are normally presented in a slightly different order, but as this isn't a linear algebra course, I think it makes more sense to present the two most familiar special cases first, then the general case, and lastly two more important special cases.
1. Dot Products (Or Inner Products - The way you're probably familiar with)
Make a vector of a row in $\X$, a vector of a column of $\Y$, and the corresponding elements of $\Z$ are the dot products of every possible pair of these x and y vectors. So loop over all possible i and j pairs, and then $z_{ji} = \vec{x}_{j} \bullet \vec{y}_{i}$, where the dot product is calculated by accumulating the element-wise products of the vectors (ref).
Python (well, NumPy and SymPy) has a nice, concise array-slicing syntax and pretty-printing functions. I'll consistently use i for column indices, j for row indices, and k as the common index. So, implementing a simple example in Python (ignoring some housekeeping), we get:
X = sp.Matrix([
    [1, 2, 1],
    [3, 4, 3]
])
Y = sp.Matrix([
    [5, 6],
    [7, 8],
    [3, 4]
])

def example1():
    Z = sp.zeros(2, 2)
    for j, i in [[0, 0], [1, 0], [0, 1], [1, 1]]:
        Z[j, i] = X[j, :] * Y[:, i]  # row-vector * col-vector
% python examples.py -n 1
⎡22 26⎤
Z = ⎢ ⎥
⎣52 62⎦
The full code for this example is in Listing 1 at the end.
If we inline and expose all the loops, we can see the loop ordering is j-i-k:
for j in range(0, x_height):
    for i in range(0, y_width):
        for k in range(0, x_width):
            Z[j, i] += X[j, k] * Y[k, i]
Note that we could swap the i and j loops, giving an i-j-k loop ordering, which would transpose the pattern in which we access the elements of Z, but it is still a dot-product procedure.
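That swapped variant looks like this; each cell of Z is still completed by a full dot product before we move on:
for i in range(0, y_width):
    for j in range(0, x_height):
        for k in range(0, x_width):
            Z[j, i] += X[j, k] * Y[k, i]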
Video 1 shows the memory access pattern for this procedure with X as a [$6\times 12$] matrix and Y as a [$12\times 16$] matrix, multiplied into Z as a [$6\times 16$] matrix, using a cache configured with 24 lines, each one cell wide. This models, say, an FPU with 24 single-float registers.
2. Tensor Products (More correctly: Outer Products)
Make a vector of a column in $\X$, a vector of a row of $\Y$, and $\Z$ is the accumulation of the matrices made from the tensor product of corresponding pairs of vectors. So loop over each column-row pair, compute the tensor product for each pair, and accumulate or sum over the resultant matrices, where the tensor product is the matrix produced by multiplying every possible pair of elements from the vectors (ref).
Our example then becomes:
def example2():
    mats = [X[:, k] * Y[k, :] for k in [0, 1, 2]]  # col-vector * row-vector
    Z = sum(mats, sp.zeros(2, 2))
% python examples.py -n 2
⎡5 6 ⎤ ⎡14 16⎤ ⎡3 4 ⎤ ⎡22 26⎤
Z = ⎢ ⎥ + ⎢ ⎥ + ⎢ ⎥ = ⎢ ⎥
⎣15 18⎦ ⎣28 32⎦ ⎣9 12⎦ ⎣52 62⎦
Inlining the loops, we see that it is a k-j-i ordering:
for k in range(0, x_width):
    for j in range(0, x_height):
        for i in range(0, y_width):
            Z[j, i] += X[j, k] * Y[k, i]
Note that we could swap the i and j loops, giving a k-i-j loop ordering, but it is still a tensor-product procedure.
Video 2 shows the access pattern for this k-j-i ordering using the same configuration as above.
Notice where the word 'accumulate' appears in the description paragraphs for procedures 1 and 2 above. In the first, each element is accumulated completely before moving on to the next, while in the second, the accumulation is spread out. If you ignore execution and memory patterns, all of these methods have effectively the same complexity. However, as can be seen here, the order in which memory or registers are accessed is very different.
3. Block Products (A generalization of the first two procedures)
$\X$ and $\Y$ can be split into blocks or sub-matrices, and the rules from procedures 1 or 2 can then be applied to the blocks as if the blocks were numbers in an ordinary matrix. For example 3, let's say we split some matrix product into a $[2\times 3]$ block matrix times a $[3\times 2]$ block matrix and then apply the patterns from 1 and 2. We get:
$$\small\begin{align*}\equationOne\end{align*}$$
This looks exactly like an ordinary matrix product written out, except we've added lines to indicate that the symbols represent blocks or sub-matrices rather than scalars.
Each of the (sub-)matrix expressions has to follow the normal rules for matrix multiplication. For example, in a block product like $AX$, the blocks $A$ and $X$ have to have the appropriate shapes to allow them to be multiplied, just as any two ordinary matrices would.
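If you want to check this symbolically, SymPy can represent and collapse block matrices. A small sketch (the block names match the equation above; the $2\times 2$ block shapes are just for illustration):
import sympy as sp

# Each block is a [2 x 2] symbol, so the left matrix is [4 x 6] overall,
# the right is [6 x 4], and all the block products conform.
A, B, C, D, E, F = (sp.MatrixSymbol(s, 2, 2) for s in 'ABCDEF')
Xb, Yb, Zb, Wb, Ub, Vb = (sp.MatrixSymbol(s, 2, 2) for s in 'XYZWUV')

left = sp.BlockMatrix([[A, B, C], [D, E, F]])
right = sp.BlockMatrix([[Xb, Yb], [Zb, Wb], [Ub, Vb]])

# block_collapse applies the block-multiplication rules symbolically,
# reproducing the AX + BZ + CU ... pattern from the 'Dot' line above.
sp.pprint(sp.block_collapse(left * right))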
This block-able structure underpins most of the optimizations you can do with matrices, and specifically two important ones:
The first is when large blocks of a matrix are zeros, such as in a block-diagonal matrix, or when a block is an identity matrix. In either case, whole chunks of computation can be entirely bypassed.
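As a quick sketch of how this first case plays out in code (plain NumPy, with the zero test done at run time for brevity; in practice you would usually know the sparsity structure ahead of time and skip the test entirely):
import numpy as np

def block_matmul_skip_zeros(X, Y, ts):
    # Block multiply that bypasses any block-product in which either
    # tile is entirely zero. Assumes square matrices whose size n is an
    # exact multiple of the tile size ts.
    n = X.shape[0]
    Z = np.zeros((n, n))
    for j in range(0, n, ts):          # block-row of X and Z
        for i in range(0, n, ts):      # block-column of Y and Z
            for k in range(0, n, ts):  # the common block index
                x_blk = X[j:j + ts, k:k + ts]
                y_blk = Y[k:k + ts, i:i + ts]
                if not x_blk.any() or not y_blk.any():
                    continue           # a whole chunk of work bypassed
                Z[j:j + ts, i:i + ts] += x_blk @ y_blk
    return Z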
The second is tiling, which we'll discuss more fully at the end, but for now we can use it to see how these methods are characterized by loop or index ordering. If we inline all the loops for line 1 of the above equation (also labeled 'Dot') where both the blocking and sub-matrix multiplication are using the i-j-k ordering from procedure 1, we get:
for outer_i in range(0, n, tile_size):
    for outer_j in range(0, n, tile_size):
        for outer_k in range(0, n, tile_size):
            for i in range(outer_i, outer_i + tile_size):
                for j in range(outer_j, outer_j + tile_size):
                    for k in range(outer_k, outer_k + tile_size):
                        Z[j, i] += X[j, k] * Y[k, i]
Video 3 shows the access pattern for this tiling using the same configuration as above.
Whereas if we look at line two (labeled 'Tensor'), it is using the k-i-j or k-j-i ordering from procedure 2. Let's choose k-j-i and keep the i-j-k ordering for the sub-matrix multiplication:
for outer_k in range(0, n, tile_size):
    for outer_j in range(0, n, tile_size):
        for outer_i in range(0, n, tile_size):
            for i in range(outer_i, outer_i + tile_size):
                for j in range(outer_j, outer_j + tile_size):
                    for k in range(outer_k, outer_k + tile_size):
                        Z[j, i] += X[j, k] * Y[k, i]
As before, these are identical except for the ordering of the loops. For this example, we can elide the two i-index loops back into a single loop and recover the classic '5-do-loop' Tiled Matrix Multiply.
for outer_k in range(0, n, tile_size):
    for outer_j in range(0, n, tile_size):
        for i in range(0, n):
            for j in range(outer_j, outer_j + tile_size):
                for k in range(outer_k, outer_k + tile_size):
                    Z[j, i] += X[j, k] * Y[k, i]
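Wrapped up as a runnable function for experimenting (a sketch assuming square [$n\times n$] matrices with n an exact multiple of tile_size), with a cross-check against NumPy's own multiply:
import numpy as np

def tiled_matmul(X, Y, tile_size):
    # The classic 5-do-loop tiled multiply from above.
    n = X.shape[0]
    Z = np.zeros((n, n))
    for outer_k in range(0, n, tile_size):
        for outer_j in range(0, n, tile_size):
            for i in range(0, n):
                for j in range(outer_j, outer_j + tile_size):
                    for k in range(outer_k, outer_k + tile_size):
                        Z[j, i] += X[j, k] * Y[k, i]
    return Z

rng = np.random.default_rng(0)
Xt, Yt = rng.random((8, 8)), rng.random((8, 8))
assert np.allclose(tiled_matmul(Xt, Yt, 4), Xt @ Yt)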
The final two procedures are special cases of this blocking with useful interpretations.
Vectors and matrices are the principal objects used in the study of linear algebra. Linear algebra is the study of vector spaces, things that can be interpreted as vector spaces, and the transformations within and between vector spaces. Any matrix represents a transformation and can be applied by multiplication to an input vector to get a transformed vector as output. For example, consider rotations in some physics or graphics model, which take a 3-vector onto a 3-vector, or a linear layer in a neural network that takes a vector embedding in one layer onto the vector embedding in the next. These transformations meet the mathematical definition of a linear operator, and instead of writing them as matrices, we can write and use them as functions that transform one space into another. This gives us access to all the machinery and interpretations of functions if we want them.
We have two methods because matrix multiplication isn't commutative: multiplying $\A$ on the left of $\B$, which gives $\A\B$, is different from multiplying $\A$ on the right of $\B$, which gives $\B\A$. Additionally, the interpretation of row and column vectors depends on the system being described. For a deeper dive on this, try Wikipedia: Row and column spaces. For our purposes here, we'll just describe the procedures. In both, we'll leave one of the matrices unsplit, treat it as a linear-transform matrix-operator, and build a function from it; the other matrix we'll split into vectors.
4. Column Space
Each column of $\Z$ is the application of $\X$, as a left-operator, to the corresponding column of $\Y$.
So $\vec{z}_{i}=X(\vec{y}_{i})=\X\vec{y}_{i}$ where $\vec{z}_{i}$ and $\vec{y}_{i}$ are column vectors of $\Z$ and $\Y$ respectively.
def example4():
    left_op = make_left_operator(X)
    Z = sp.zeros(2, 2)
    for i in [0, 1]:
        Z[:, i] = left_op(Y[:, i])
The loop order here is i-k-j.
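Inlining the loops, each column of Z is built up as a linear combination of the columns of X, scaled by the cells of the corresponding column of Y:
for i in range(0, y_width):
    for k in range(0, x_width):
        for j in range(0, x_height):
            Z[j, i] += X[j, k] * Y[k, i]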
5. Row Space
Each row of $\Z$ is the application of $\Y$, as a right-operator, to the corresponding row of $\X$.
So $\vec{z}_{j}=Y(\vec{x}_{j})=\vec{x}_{j}\Y$ where $\vec{z}_{j}$ and $\vec{x}_{j}$ are row vectors of $\Z$ and $\X$ respectively.
def example5():
    right_op = make_right_operator(Y)
    Z = sp.zeros(2, 2)
    for j in [0, 1]:
        Z[j, :] = right_op(X[j, :])
The loop order here is j-k-i.
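Inlining the loops, each row of Z is built up as a linear combination of the rows of Y, scaled by the cells of the corresponding row of X:
for j in range(0, x_height):
    for k in range(0, x_width):
        for i in range(0, y_width):
            Z[j, i] += X[j, k] * Y[k, i]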
Summarizing Access Patterns
Looking at the 3-loop procedures: there are six ways to order three loops. Procedures 1 and 2 account for four orderings because the i-j order doesn’t really change anything, and procedures 4 and 5 bring the total to six.
However, the important difference between these procedures lies in their memory access patterns, which we can compare by considering what happens as the different methods loop:
- Method 1: Overall, each cell of Z is accessed exactly once, but a whole row of X and a whole column of Y are accessed in each iteration.
- Method 2: Overall, each column of X and row of Y is accessed exactly once, but every cell of Z is accessed in each iteration.
- Method 4: Overall, each column of Z and column of Y is accessed exactly once, but every cell of X is accessed in each iteration.
- Method 5: Overall, each row of Z and row of X is accessed exactly once, but every cell of Y is accessed in each iteration.
Tiling
In real implementations, the multiplication procedure needs to account for memory caching and parallel execution to approach optimal performance. We can use the block-matrix structure shown in Procedure 3 to distribute these blocks—now commonly called tiles—across multiple threads, while arranging their size to fit within the appropriate cache line or memory page.
If you know the general shape of the matrix multiplication ahead of time, tiling allows you to construct a procedure that keeps all available execution units active, while minimizing delays due to spills or fills between memory hierarchy levels.
Consider parallelizing Example 4 above: The matrix operator $\X$ would need to be broadcast to every thread, while the column vectors can be scattered across threads and gathered back appropriately. This approach maps well to a GPU, which has execution units capable of handling complete 4-vector and $4\times 4$ matrix operations in hardware as single instructions.
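A minimal CPU-side sketch of that broadcast/scatter/gather shape, using Python threads (the function name and n_workers are illustrative; this shows the structure of the computation rather than GPU-level performance):
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def column_space_parallel(X, Y, n_workers=4):
    # Broadcast X to every worker, scatter the columns of Y across the
    # pool, and gather the resulting columns back into Z.
    def one_column(i):
        return X @ Y[:, i]  # X applied as a left-operator to column i
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        cols = list(pool.map(one_column, range(Y.shape[1])))
    return np.column_stack(cols)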
Contrast this with the blocking structure in Example 3, which suggests scattering three independent sets of blocks, each computing a matrix that is then gathered and accumulated. This structure is well suited for a multicore CPU, where independent sets can be managed and kept separate.
Cache considerations are similar: If you look closely at the Y matrix in Video 1, you’ll notice a gap in the column, indicating that as it loops over that column, one of the cells is not currently loaded into a register and will need to be loaded for the multiplication to continue. This is problematic. I set it up to represent the worst-case scenario, where there aren’t quite enough registers available, resulting in each cell being used only once before making way for another value to be loaded.
Again contrast this with the tiled version of the same matrices in Video 3. In this configuration, every cell is used multiple times before being evicted, allowing more individual multiplications to occur between load operations. Modern processors can load registers quite quickly, but there’s still a multi-cycle overhead to do so, even assuming the entire computation fits within the processor's Level 1 cache.
However, this principle applies recursively for each cache level: If the tiles are slightly too large to fit entirely within a cache level, then each entry is used only once, and the computation has to wait while the next needed value is fetched from the next cache level.
Fortunately, tiling can be applied recursively as well. We can split matrices into tiles that fit neatly into the Level 2 cache, then divide those tiles so that each sub-tile fits neatly into the Level 1 cache, and then split the sub-tiles so they fit within the register space.
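Schematically, a two-level version of the earlier loop nest looks like this (a sketch assuming n is an exact multiple of l2_tile and l2_tile an exact multiple of l1_tile; l2_tile would be sized for the Level 2 cache and l1_tile for the Level 1 cache or registers):
for j2 in range(0, n, l2_tile):
    for i2 in range(0, n, l2_tile):
        for k2 in range(0, n, l2_tile):
            # Level 1 tiles nested inside each Level 2 tile.
            for j1 in range(j2, j2 + l2_tile, l1_tile):
                for i1 in range(i2, i2 + l2_tile, l1_tile):
                    for k1 in range(k2, k2 + l2_tile, l1_tile):
                        for j in range(j1, j1 + l1_tile):
                            for i in range(i1, i1 + l1_tile):
                                for k in range(k1, k1 + l1_tile):
                                    Z[j, i] += X[j, k] * Y[k, i]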
Video 4 shows a two-level system with a slightly larger matrix than before. The registers are configured as 24 single-cell registers, shown in red as before. The cache has 24 lines of 8 cells each, shown in green.
For a large matrix, like those used in large language models, you can construct a procedure that first tiles the matrices over, say, GPU cores, then the L2 cache, then execution threads, then the L1 cache, and finally execution registers. Each level has its own tile size and can use a procedure that is optimal for that level.
The simulator Python code for these videos can be found on GitHub. It can handle any size of matrix, with any tiling hierarchy and any number of cache levels. The simulation is quite quick; the rendering, however, takes forever.
Conclusions
The mathematical definition of matrix multiplication doesn’t care about the foibles of an implementation. It knows nothing about memory hierarchies or multi-core parallel systems—systems designed to perform many operations on small amounts of data per execution thread. Even systems with vector operations still assume the vectors are relatively small. Large matrix multiplication, however, makes the exact opposite assumption: it performs a single operation spread over extensive data. The issues this creates can be mitigated by using tiling and choosing appropriate loop orderings. By breaking the computation into smaller tile-computations, the overall procedure can be tailored to match the structure and assumptions of the hardware being used, even if that structure is complex.
While many applications can use a library that hides this complexity behind a convenient API, getting the match right is often challenging, and small changes can result in orders of magnitude differences in execution times. There are also plenty of cases where a library cannot help. Understanding these procedures and recognizing them when you see them is an essential tool for debugging or creating anything that relies on matrix multiplication.
\newcommand\A{\mathbf{A}} \newcommand\B{\mathbf{B}} \newcommand\X{\mathbf{X}} \newcommand\Y{\mathbf{Y}} \renewcommand\Z{\mathbf{Z}} \newcommand\equationOne{% \left[\begin{array}{c|c|c} A & B & C\\ \hline D & E & F \end{array}\right]\left[\begin{array}{c|c} X & Y\\ \hline Z & W\\ \hline U & V \end{array}\right] & =\left[\begin{array}{c|c} AX+BZ+CU & AY+BW+CV\\ \hline DX+EZ+FU & DY+EW+FV \end{array}\right] & \text{Dot (1)}\\ & =\left[\begin{array}{c|c} AX & AY\\ \hline DX & DY \end{array}\right]+\left[\begin{array}{c|c} BZ & BW\\ \hline EZ & EW \end{array}\right]+\left[\begin{array}{c|c} CU & CV\\ \hline FU & FV \end{array}\right] & \text{Tensor (2)} }%
Listing 1: The complete Python code for the code examples above
# from numpy import zeros, array, dot, tensordot
# import numpy.typing as npt
from typing import Callable
import sympy as sp
from argparse import ArgumentParser
def doArgs():
    parser = ArgumentParser(description="Matrix Multiplication Examples")
    parser.add_argument('-n', '--number', dest='n', type=int, required=True,
                        help='example number')
    return parser.parse_args()

def tensor(x: sp.Matrix, y: sp.Matrix) -> sp.Matrix:
    m = sp.tensorproduct(x, y)
    sp.pprint(m)
    return m

def dot(x: sp.Matrix, y: sp.Matrix) -> float:
    return x.dot(y)

def make_left_operator(m: sp.Matrix) -> Callable[[sp.Matrix], sp.Matrix]:
    def prod(v: sp.Matrix) -> sp.Matrix:
        return m * v
    return prod

def make_right_operator(m: sp.Matrix) -> Callable[[sp.Matrix], sp.Matrix]:
    def prod(v: sp.Matrix) -> sp.Matrix:
        return v * m
    return prod

X = sp.Matrix([
    [1, 2, 1],
    [3, 4, 3]
])
Y = sp.Matrix([
    [5, 6],
    [7, 8],
    [3, 4]
])

sX, sY, sZ = sp.symbols('X,Y,Z')

def example1():
    Z = sp.zeros(2, 2)
    for j, i in [[0, 0], [1, 0], [0, 1], [1, 1]]:
        Z[j, i] = X[j, :] * Y[:, i]  # row-vector * col-vector
    sp.pprint(sp.Eq(sZ, Z, evaluate=False))

def example2():
    mats = [X[:, k] * Y[k, :] for k in [0, 1, 2]]  # col-vector * row-vector
    Z = sum(mats, sp.zeros(2, 2))
    sp.pprint(sp.Eq(sZ, sp.Eq(sp.MatAdd(*mats), Z)))

def example3():
    print("There is no code for example 3")

def example4():
    left_op = make_left_operator(X)
    Z = sp.zeros(2, 2)
    for i in [0, 1]:
        Z[:, i] = left_op(Y[:, i])
    sp.pprint(sp.Eq(sZ, Z, evaluate=False))

def example5():
    right_op = make_right_operator(Y)
    Z = sp.zeros(2, 2)
    for j in [0, 1]:
        Z[j, :] = right_op(X[j, :])
    sp.pprint(sp.Eq(sZ, Z, evaluate=False))

def main():
    args = doArgs()
    examples = [example1, example2, example3, example4, example5]
    if 0 < args.n <= len(examples):
        examples[args.n - 1]()

if __name__ == "__main__":
    main()