Learning Two Layer Rectified Neural Networks in Polynomial Time
Abstract
We consider the following fundamental problem in the study of neural networks: given input examples $x \in \mathbb{R}^d$ and their vector-valued labels, as defined by an underlying generative neural network, recover the weight matrices of this network. We consider two-layer networks, mapping $\mathbb{R}^d$ to $\mathbb{R}^m$, with a single hidden layer and $k$ non-linear activation units $f(\cdot)$, where $f(x) = \max \{x , 0\}$ is the ReLU activation function. Such a network is specified by two weight matrices, $\mathbf{U}^* \in \mathbb{R}^{m \times k}, \mathbf{V}^* \in \mathbb{R}^{k \times d}$, such that the label of an example $x \in \mathbb{R}^{d}$ is given by $\mathbf{U}^* f(\mathbf{V}^* x)$, where $f(\cdot)$ is applied coordinate-wise. Given $n$ samples $x^1,\dots,x^n \in \mathbb{R}^d$ as a matrix $\mathbf{X} \in \mathbb{R}^{d \times n}$ and the labels $\mathbf{U}^* f(\mathbf{V}^* \mathbf{X})$ of the network on these samples, our goal is to recover the weight matrices $\mathbf{U}^*$ and $\mathbf{V}^*$. More generally, the labels $\mathbf{U}^* f(\mathbf{V}^* \mathbf{X})$ may be corrupted by noise, in which case we instead observe $\mathbf{U}^* f(\mathbf{V}^* \mathbf{X}) + \mathbf{E}$, where $\mathbf{E}$ is some noise matrix. Even in this case, we may still wish to recover good approximations to the weight matrices $\mathbf{U}^*$ and $\mathbf{V}^*$. In this work, we develop algorithms and hardness results under varying assumptions on the input and noise. Although the problem is NP-hard even for $k=2$, by assuming Gaussian marginals over the input $\mathbf{X}$ we are able to develop polynomial time algorithms for the approximate recovery of $\mathbf{U}^*$ and $\mathbf{V}^*$. Perhaps surprisingly, in the noiseless case our algorithms recover $\mathbf{U}^*,\mathbf{V}^*$ \textit{exactly}, i.e., with no error, in \textit{strongly} polynomial time. To the best of our knowledge, this is the first algorithm to accomplish exact recovery for the ReLU activation function. Fo