High-Dimensional Conditional Density Estimation

This is the coding illustration of the first part of my job market paper.

I propose a conditional density estimator that is designed for settings with high-dimensional covariates: it converts the problem of estimating a conditional density into a series of conditional mean estimation problems, which can be handled by any modern machine learning method.

Why do you need to estimate a conditional density?

"What's the difference between your estimator and Izibicki and Lee (2017)?"

Some Theory to Motivate

Note: If you want to dive deep into the theory, refer to my job market paper for more details.

Suppose we want to learn the likelihood of some variable $Y$ given covariates $X$, i.e. the conditional density $f_{Y|X}$.

If you have some statistics background, you may recall that one way of estimating the conditional density relies on the following identity/definition:

$$ f_{Y|X}(y|x) = \frac{f_{Y,X}(y,x)}{f_{X}(x)}$$

where $f_{Y,X}$ and $f_X$ denote the joint density of $(Y,X)$ and the marginal density of $X$, respectively. One popular approach uses kernel methods to estimate $f_{Y,X}$ and $f_X$ separately and then forms the ratio. However, when $X$ has even just moderate dimension, the curse of dimensionality kicks in quickly.
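To make this concrete, here is a minimal sketch of the ratio approach with a single covariate, using `scipy.stats.gaussian_kde`; the simulated data and the function names (`f_joint`, `f_marg`, `f_cond`) are mine, purely for illustration.

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=n)                   # a single covariate, for illustration
Y = X + rng.normal(scale=0.5, size=n)    # an outcome that depends on X

# Kernel estimates of the joint density f_{Y,X} and the marginal density f_X
f_joint = gaussian_kde(np.vstack([Y, X]))
f_marg = gaussian_kde(X)

def f_cond(y, x):
    """Ratio estimate of the conditional density f_{Y|X}(y | x)."""
    return f_joint(np.vstack([np.atleast_1d(y), np.atleast_1d(x)])) / f_marg(np.atleast_1d(x))

print(f_cond(0.5, 0.3))                  # estimated conditional density at y = 0.5 given x = 0.3
```

With more covariates, each kernel estimate deteriorates quickly, which is exactly the curse of dimensionality mentioned above.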

In contrast, my method relies on the following representation:

$$ f_{Y|X}(y|x) = \sum_{j=1}^\infty E[\phi_j(Y)|X=x]\,\phi_j(y) $$

where $\{\phi_j\}_{j=1}^\infty$ are known functions chosen by the researcher that form an orthonormal basis.

You may wonder: this is a rather daunting expression that looks more complicated than what we started with, so how does it help us?

The key insight is that this expression effectively converts the problem of estimating a conditional density into the problem of estimating many conditional means, each of which can be estimated with any state-of-the-art machine learning method.

A short comment on the basis:
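The basis is something the researcher gets to choose. Purely as an illustrative assumption for the code sketches below, I use the cosine basis on $[0,1]$ (so $Y$ is assumed to be rescaled to the unit interval); this is not necessarily the basis used in the paper. A quick numerical check confirms orthonormality:

```python
import numpy as np

def cosine_basis(y, J):
    """First J functions of the cosine basis on [0,1]:
    phi_1(y) = 1 and phi_j(y) = sqrt(2) * cos(pi * (j - 1) * y) for j >= 2."""
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    j = np.arange(J).reshape(1, -1)              # 0, 1, ..., J-1
    Phi = np.sqrt(2.0) * np.cos(np.pi * j * y)
    Phi[:, 0] = 1.0                              # the first basis function is the constant
    return Phi                                   # shape (len(y), J)

# Numerical check of orthonormality: the integral of phi_j * phi_k over [0,1] is 1{j = k}
grid = np.linspace(0.0, 1.0, 5001)
Phi = cosine_basis(grid, 5)
gram = Phi.T @ Phi / len(grid)                   # Riemann approximation of the Gram matrix
print(np.round(gram, 2))                         # approximately the 5 x 5 identity matrix
```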

Main Idea of the Estimator

My estimator takes the following form (we use the hat ^ notation to denote estimators)

$$ \hat{f}_{J}(y|x) = \sum_{j=1}^J \hat{E}[\phi_j(Y)|X=x]\phi_j(y)$$

and we need to emphasize two main features:

  1. For $j=1,\cdots, J$, we need to estimate the conditional means $E[\phi_j(Y)|X]$.

    • My framework allows any machine learner to be used for these conditional means (see the sketch after this list).

  2. We also need a data-driven way of choosing the optimal series cutoff $J$.

    • I leverage a cross-validation procedure to pick $J$, and a cross-fitting/averaging procedure to build the final estimator.
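Here is a minimal sketch of how $\hat{f}_J$ can be assembled, reusing the `cosine_basis` helper from the basis sketch above and using scikit-learn's random forest purely as one example of "any machine learner". The function names, the learner, and the simulated data are all illustrative assumptions of mine, not prescriptions from the paper.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

def fit_cond_means(X, Y, J):
    """Fit one regression of phi_j(Y) on X for each j = 1, ..., J (here: random forests)."""
    Phi = cosine_basis(Y, J)                     # cosine_basis() from the basis sketch above
    return [RandomForestRegressor(n_estimators=100, random_state=0).fit(X, Phi[:, j])
            for j in range(J)]

def f_hat(y, x, models):
    """Evaluate f_hat_J(y|x) = sum_j E_hat[phi_j(Y) | X = x] * phi_j(y)."""
    x = np.atleast_2d(x)
    J = len(models)
    cond_means = np.column_stack([m.predict(x) for m in models])   # E_hat[phi_j(Y) | X = x]
    return cond_means @ cosine_basis(y, J).T                       # rows index x, columns index y

# Illustration on simulated data (Y is supported on [0,1], as the cosine basis requires)
rng = np.random.default_rng(0)
X = rng.uniform(size=(1000, 5))
Y = np.clip(0.5 + 0.4 * (X[:, 0] - 0.5) + 0.1 * rng.normal(size=1000), 0.0, 1.0)

models = fit_cond_means(X, Y, J=8)
y_grid = np.linspace(0.0, 1.0, 101)
density_at_x0 = f_hat(y_grid, X[0], models)      # estimated f(y | x = X[0]) on the grid
```

Any other regression learner (boosting, neural nets, penalized regression, and so on) could be swapped in for the random forest; that flexibility is exactly feature 1 above.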

A Simple Cross-Validation "Algorithm"

For illustration purposes, let's consider 2-fold CV; the procedure generalizes easily to $K$-fold CV.
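The steps are easiest to see in code. Below is a rough sketch of a 2-fold procedure of this kind, reusing `cosine_basis` and `fit_cond_means` from the sketches above; the empirical risk it minimizes is the one derived in the next section, and the final estimator simply averages the two fold-specific estimators at the selected cutoff. These details, and the function names, are simplifying assumptions of mine for illustration; see the paper for the exact algorithm.

```python
import numpy as np

# Uses cosine_basis() and fit_cond_means() from the sketches above.

def empirical_risk(models, X_out, Y_out, J):
    """Held-out empirical risk R_n(f_hat_J) in its simplified series form (derived below)."""
    M = np.column_stack([m.predict(X_out) for m in models[:J]])    # E_hat[phi_j(Y) | X_i]
    f_hat_at_data = np.sum(M * cosine_basis(Y_out, J), axis=1)     # f_hat_J(Y_i, X_i)
    return np.mean(np.sum(M**2, axis=1) - 2.0 * f_hat_at_data)

def two_fold_estimator(X, Y, J_max=15, seed=0):
    n = len(Y)
    idx = np.random.default_rng(seed).permutation(n)
    folds = [idx[: n // 2], idx[n // 2:]]

    # Step 1: on each fold, estimate the conditional means E[phi_j(Y)|X] for j = 1, ..., J_max
    models = [fit_cond_means(X[f], Y[f], J_max) for f in folds]

    # Step 2: choose J by minimizing the empirical risk evaluated on the held-out fold
    risks = [np.mean([empirical_risk(models[k], X[folds[1 - k]], Y[folds[1 - k]], J)
                      for k in (0, 1)])
             for J in range(1, J_max + 1)]
    J_hat = 1 + int(np.argmin(risks))

    # Step 3: cross-fit/average the two fold-specific estimators at the chosen cutoff
    def f_bar(y, x):
        x = np.atleast_2d(x)
        Phi_y = cosine_basis(y, J_hat)
        parts = [np.column_stack([m.predict(x) for m in models[k][:J_hat]]) @ Phi_y.T
                 for k in (0, 1)]
        return 0.5 * (parts[0] + parts[1])

    return f_bar, J_hat

# Continuing the simulated example above
f_bar, J_hat = two_fold_estimator(X, Y, J_max=12)
print(J_hat, f_bar(np.linspace(0.0, 1.0, 5), X[0]))
```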

A key contribution of my paper is showing that the resulting cross-fitted estimator $\bar{f}$ is in fact optimal: it is asymptotically equivalent to $\hat{f}_{J^*}$ for the best possible cutoff $J^*$.

Note: See below for a short discussion on the empirical risk we use in the cross-validation.

Empirical Risk for Cross-Validation

Unlike the usual prediction problem, where we can compare the predicted value against the actual value, in my setting the object we are trying to estimate is the unknown conditional density $f_{Y|X}$. So how do we do cross-validation in this setting?

Suppose that we are trying to find $\hat{f}$ that minimizes the integrated squared error:

$$ E_X\left[\int \big(\hat{f}(y,X) - f_{Y|X}(y|X)\big)^2\, dy\right] $$

which is equivalent to minimizing

$$ R(\hat{f}) := E_X\left[\int \big(\hat{f}(y,X) - f_{Y|X}(y|X)\big)^2\, dy\right] - E_X\left[\int f_{Y|X}^2(y|X)\, dy\right] $$

This is not something that we can work with since $f_{Y|X}$ is unknown.

In the paper, we show a very convenient fact: $R(\hat{f})$ can be rewritten as an expression that does not involve the unknown $f_{Y|X}$:

$$ R(\hat{f}) = E[\int \hat{f}^2(y,X) dy - 2\hat{f}(Y,X)] $$
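To see why this holds (a sketch of the standard argument, treating $\hat{f}$ as fixed, e.g. estimated on a separate sample), expand the square in the definition of $R(\hat{f})$:

$$ E_X\left[\int \big(\hat{f}(y,X) - f_{Y|X}(y|X)\big)^2\, dy\right] = E_X\left[\int \hat{f}^2(y,X)\,dy\right] - 2\,E_X\left[\int \hat{f}(y,X)\,f_{Y|X}(y|X)\,dy\right] + E_X\left[\int f_{Y|X}^2(y|X)\,dy\right]. $$

The cross term integrates $\hat{f}(y,X)$ against the conditional density of $Y$ given $X$, so by the law of iterated expectations it equals $E_X\big[E[\hat{f}(Y,X)\mid X]\big] = E[\hat{f}(Y,X)]$. The last term does not involve $\hat{f}$ and cancels with the term subtracted in the definition of $R(\hat{f})$, leaving exactly the expression above.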

Since this expression involves only $\hat{f}$ and the observed data, we can work with the sample analogue

$$ R_n(\hat{f}) = \frac{1}{n}\sum_{i=1}^n \left[\int \hat{f}^2(y,X_i)\, dy - 2\hat{f}(Y_i,X_i)\right]. $$

In our case with $\hat{f}_J$, this expression simplifies even further to

$$ R_n(\hat{f}_J) = \frac{1}{n}\sum_{i=1}^n \left[\sum_{j=1}^J \big(\hat{E}[\phi_j(Y)|X_i]\big)^2 - 2\hat{f}_J(Y_i,X_i)\right], $$

which avoids the integral altogether: by orthonormality of the basis, $\int \hat{f}_J^2(y,X_i)\,dy = \sum_{j=1}^J \big(\hat{E}[\phi_j(Y)|X_i]\big)^2$.

This is the empirical risk that we are going to use in our cross-validation procedure.
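As a quick numerical sanity check of this simplification (a sketch reusing the `cosine_basis` helper from above, with made-up numbers standing in for the fitted conditional means $\hat{E}[\phi_j(Y)|X_i]$), the integral form and the series form of the empirical risk agree:

```python
import numpy as np

# Uses cosine_basis() from the basis sketch above.
rng = np.random.default_rng(0)
n, J = 200, 6
M = rng.normal(size=(n, J))              # stand-in for the fitted E_hat[phi_j(Y) | X_i]
Y_obs = rng.uniform(size=n)              # stand-in for the observed Y_i (supported on [0,1])
f_hat_at_data = np.sum(M * cosine_basis(Y_obs, J), axis=1)      # f_hat_J(Y_i, X_i)

# Integral form: (1/n) * sum_i [ int f_hat_J(y, X_i)^2 dy - 2 * f_hat_J(Y_i, X_i) ]
y_grid = np.linspace(0.0, 1.0, 2001)
dy = y_grid[1] - y_grid[0]
w = np.full(y_grid.shape, dy)                                    # trapezoidal-rule weights
w[0] = w[-1] = 0.5 * dy
f_hat_grid = M @ cosine_basis(y_grid, J).T                       # f_hat_J(y, X_i) on the grid
R_integral = np.mean(f_hat_grid**2 @ w - 2.0 * f_hat_at_data)

# Series form: by orthonormality, the integral collapses to sum_j (E_hat[phi_j(Y)|X_i])^2
R_series = np.mean(np.sum(M**2, axis=1) - 2.0 * f_hat_at_data)

print(R_integral, R_series)              # the two numbers should be (nearly) identical
```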

Comments on the Code:

The code is simple and concise, and most importantly, it works. Still, a few comments are in order, and some future improvements could readily be implemented.