torch.sparse.mm

torch.sparse.mm(mat1, mat2)

Performs a matrix multiplication of the sparse matrix mat1 and the (sparse or strided) matrix mat2. Similar to torch.mm(), if mat1 is an (n × m) tensor and mat2 is an (m × p) tensor, the output will be an (n × p) tensor. mat1 needs to have sparse_dim = 2. This function also supports backward for both matrices. Note that the gradient of mat1 is a coalesced sparse tensor.

Parameters
  • mat1 (SparseTensor) – the first sparse matrix to be multiplied

  • mat2 (Tensor) – the second matrix to be multiplied, which could be sparse or dense

Shape:

The format of the output tensor of this function follows:
  • sparse x sparse -> sparse

  • sparse x dense -> dense
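The layout rule above can be checked with a short sketch. The matrices and variable names here are illustrative and not part of the original page, and the snippet assumes a PyTorch build in which the sparse x sparse product is available on the current device.

```python
import torch

# Small hand-built matrices (illustrative values, not from the original page).
dense_a = torch.tensor([[1.0, 0.0, 2.0],
                        [0.0, 3.0, 0.0]])
dense_b = torch.tensor([[1.0, 0.0],
                        [0.0, 1.0],
                        [4.0, 0.0]])

sparse_a = dense_a.to_sparse()  # COO tensor with sparse_dim = 2
sparse_b = dense_b.to_sparse()

out_dense = torch.sparse.mm(sparse_a, dense_b)    # sparse x dense
out_sparse = torch.sparse.mm(sparse_a, sparse_b)  # sparse x sparse

# Reference result computed with the ordinary dense product.
expected = torch.mm(dense_a, dense_b)

print(out_dense.layout)   # layout of the sparse x dense result
print(out_sparse.layout)  # layout of the sparse x sparse result
print(torch.allclose(out_dense, expected))
print(torch.allclose(out_sparse.to_dense(), expected))
```

Comparing against torch.mm() on the dense copies is only a sanity check; in practice the sparse inputs would be built directly rather than converted from dense tensors.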

Example:

>>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
>>> a
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 1.5901,  0.0183, -0.6146,  1.8061, -0.0112,  0.6302]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)

>>> b = torch.randn(3, 2, requires_grad=True)
>>> b
tensor([[-0.6479,  0.7874],
        [-1.2056,  0.5641],
        [-1.1716, -0.9923]], requires_grad=True)

>>> y = torch.sparse.mm(a, b)
>>> y
tensor([[-0.3323,  1.8723],
        [-1.8951,  0.7904]], grad_fn=<SparseAddmmBackward>)
>>> y.sum().backward()
>>> a.grad
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 0.1394, -0.6415, -2.1639,  0.1394, -0.6415, -2.1639]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo)
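
In the example above, each row of a.grad equals the row sums of b, because the gradient of y.sum() with respect to a is a matrix of ones multiplied by the transpose of b. The following sketch reproduces that check with hand-picked values (an assumption made for determinism; the original example uses torch.randn). All entries of a are nonzero here, so the sparse gradient covers every position.

```python
import torch

# Hand-picked values (illustrative); every entry of `a` is nonzero.
a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]]).to_sparse().requires_grad_(True)
b = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]], requires_grad=True)

y = torch.sparse.mm(a, b)
y.sum().backward()

# For loss = y.sum(), the upstream gradient is a (2 x 2) matrix of ones, so
# d(loss)/da = ones @ b^T and d(loss)/db = a^T @ ones.
ones = torch.ones(2, 2)
expected_a_grad = ones @ b.detach().t()
expected_b_grad = a.detach().to_dense().t() @ ones

a_grad_matches = torch.allclose(a.grad.to_dense(), expected_a_grad)
b_grad_matches = torch.allclose(b.grad, expected_b_grad)
print(a_grad_matches, b_grad_matches)
```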
