How to Capture Higher-order Correlations? Generalizing Matrix Softmax Attention to Kronecker Computation
Abstract
In the classical transformer attention scheme, we are given three $n \times d$ size matrices $Q, K, V$ (the query, key, and value tokens), and the goal is to compute a new $n \times d$ size matrix $D^{-1} \exp(QK^\top) V$ where $D = \mathrm{diag}( \exp(QK^\top) {\bf 1}_n )$. Here, $\exp(\cdot)$ is applied entry-wise and ${\bf 1}_n$ denotes a length-$n$ vector whose entries are all ones. Intuitively, attention computation captures pairwise information between words in a sentence, but not higher-order information. Indeed, recent work \cite{sht23} has shown that attention units cannot solve simple problems about detecting triples of connected words. In this work, we study a generalization of attention that captures triple-wise correlations. The generalization is based on computations involving tensors defined by tuples of words. More formally, given five $n \times d$ size matrices $Q, K_1, K_2, V_1$, and $V_2$ (generalized query, key, and value tokens), our new goal is to compute an $n \times d$ size matrix $D^{-1} \exp( Q ( K_1 \oslash K_2)^\top ) (V_1 \oslash V_2)$ where $D = \mathrm{diag}( \exp( Q ( K_1 \oslash K_2)^\top ) {\bf 1}_{n^2} )$ and $K_1 \oslash K_2 \in \mathbb{R}^{n^2 \times d}$ denotes the column-wise Kronecker product of $K_1$ and $K_2$. This generalization is indeed able to solve problems about detecting triple-wise connections that were shown to be impossible for transformers. The potential downside of this generalization is that it appears to be even harder to compute: the straightforward algorithm requires cubic time in $n$. However, we show that in the bounded-entry setting (which arises in practice and is well studied in both theory and practice), there is actually a near-linear time algorithm. More precisely, we show that bounded entries are both necessary and sufficient for quickly performing the generalized computation: $\bullet$ On the positive side, if all entries of the input matrices are bounded above by $o(\sqrt[3]