Amortized Maximum Inner Product Search with Learned Support Functions
Abstract
Maximum inner product search (MIPS) is a core subroutine in machine learning: given a query, it must identify the database vector with the largest inner product with that query. We propose amortized MIPS, a learning-based approach that trains neural networks to predict MIPS solutions directly, amortizing the computational cost of search across queries drawn from a known distribution. Our key insight is that the MIPS value function, the maximum inner product viewed as a function of the query, is convex (as the pointwise maximum of linear functions), and its gradient at a query recovers the optimal database vector. We explore two complementary architectures: (1) Input Convex Neural Networks (ICNNs), which learn the convex value function and recover the optimal match via gradient computation, and (2) VectorICNNs, which regress the argmax directly, bypassing gradient computation entirely at inference time. For ICNNs, we combine score regression with gradient-matching losses; for VectorICNNs, we introduce a score consistency loss derived from Euler's theorem for homogeneous functions. We further propose homogenization wrappers that enforce positive 1-homogeneity, which theoretically ties function values to gradients. Our experiments on retrieval benchmarks demonstrate that convexity provides an effective inductive bias: the learned potentials achieve high match rates while requiring only a single forward pass at inference time.
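For concreteness, the observation underlying the abstract can be sketched as follows (the notation here is introduced for illustration only and is not drawn from the paper body). For a database $\{x_i\}_{i=1}^n \subset \mathbb{R}^d$, the value function

\[
  f(q) \;=\; \max_{1 \le i \le n} \langle x_i, q \rangle
\]

is convex as a pointwise maximum of linear functions and positively 1-homogeneous, since $f(\alpha q) = \alpha f(q)$ for all $\alpha > 0$. Wherever the maximizer $i^*(q) = \arg\max_i \langle x_i, q \rangle$ is unique, Danskin's theorem gives

\[
  \nabla f(q) \;=\; x_{i^*(q)},
  \qquad
  f(q) \;=\; \langle \nabla f(q),\, q \rangle ,
\]

where the second identity is Euler's relation for positively 1-homogeneous functions, the consistency condition referenced above.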