diff --git a/.pyre_configuration b/.pyre_configuration index 27941fe8a..be0c1df52 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -1,7 +1,12 @@ { "site_package_search_strategy": "pep561", "source_directories": [ - {"import_root": ".", "source": "gpytorch/"} + "gpytorch/" ], + "ignore_all_errors": [ + "gpytorch/functions/*.py", + "gpytorch/lazy/*.py", + "gpytorch/test/*.py" ], "search_path": [ ".", diff --git a/docs/source/conf.py b/docs/source/conf.py index f669ce6e8..e132ace1d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -51,6 +51,8 @@ def find_version(*file_paths): shutil.rmtree(examples_dest) os.mkdir(examples_dest) +# Include examples in documentation +# This adds a lot of time to the doc build; to bypass it, set the environment variable SKIP_EXAMPLES=true for root, dirs, files in os.walk(examples_source): for dr in dirs: os.mkdir(os.path.join(root.replace(examples_source, examples_dest), dr)) @@ -58,7 +60,19 @@ def find_version(*file_paths): if os.path.splitext(fil)[1] in [".ipynb", ".md", ".rst"]: source_filename = os.path.join(root, fil) dest_filename = source_filename.replace(examples_source, examples_dest) - shutil.copyfile(source_filename, dest_filename) + + # If we're skipping examples, put a dummy file in place + if os.getenv("SKIP_EXAMPLES"): + if dest_filename.endswith("index.rst"): + shutil.copyfile(source_filename, dest_filename) + else: + with open(os.path.splitext(dest_filename)[0] + ".rst", "w") as f: + basename = os.path.splitext(os.path.basename(dest_filename))[0] + f.write(f"{basename}\n" + "=" * 80) + + # Otherwise, copy over the real example files + else: + shutil.copyfile(source_filename, dest_filename) # -- Project information ----------------------------------------------------- @@ -282,6 +296,10 @@ def _process(annotation, config): arg = annotation.__args__[0] res = "list(" + _process(arg, config) + ")" + # Keep any Dict[*A*, *B*] annotation as its string representation + elif str(annotation).startswith("typing.Dict"): + res = str(annotation) + # Convert any Iterable[*A*] into "iterable(*A*)" elif str(annotation).startswith("typing.Iterable"): arg = annotation.__args__[0] diff --git a/docs/source/likelihoods.rst b/docs/source/likelihoods.rst index 2607d5ff2..a4c215a6e 100644 --- a/docs/source/likelihoods.rst +++ b/docs/source/likelihoods.rst @@ -12,6 +12,7 @@ Likelihood -------------------- .. 
autoclass:: Likelihood + :special-members: __call__ :members: diff --git a/examples/04_Variational_and_Approximate_GPs/Non_Gaussian_Likelihoods.ipynb b/examples/04_Variational_and_Approximate_GPs/Non_Gaussian_Likelihoods.ipynb index fc7604833..fb9b3667b 100644 --- a/examples/04_Variational_and_Approximate_GPs/Non_Gaussian_Likelihoods.ipynb +++ b/examples/04_Variational_and_Approximate_GPs/Non_Gaussian_Likelihoods.ipynb @@ -131,25 +131,25 @@ "output_type": "stream", "text": [ "Iter 1/50 - Loss: 0.908\n", - "Iter 2/50 - Loss: 4.296\n", - "Iter 3/50 - Loss: 8.896\n", - "Iter 4/50 - Loss: 3.563\n", - "Iter 5/50 - Loss: 5.978\n", - "Iter 6/50 - Loss: 6.632\n", - "Iter 7/50 - Loss: 6.222\n", + "Iter 2/50 - Loss: 4.272\n", + "Iter 3/50 - Loss: 8.886\n", + "Iter 4/50 - Loss: 3.560\n", + "Iter 5/50 - Loss: 5.968\n", + "Iter 6/50 - Loss: 6.614\n", + "Iter 7/50 - Loss: 6.212\n", "Iter 8/50 - Loss: 4.975\n", - "Iter 9/50 - Loss: 3.972\n", - "Iter 10/50 - Loss: 3.593\n", - "Iter 11/50 - Loss: 3.329\n", - "Iter 12/50 - Loss: 2.798\n", - "Iter 13/50 - Loss: 2.329\n", + "Iter 9/50 - Loss: 3.976\n", + "Iter 10/50 - Loss: 3.596\n", + "Iter 11/50 - Loss: 3.327\n", + "Iter 12/50 - Loss: 2.791\n", + "Iter 13/50 - Loss: 2.325\n", "Iter 14/50 - Loss: 2.140\n", "Iter 15/50 - Loss: 1.879\n", - "Iter 16/50 - Loss: 1.661\n", + "Iter 16/50 - Loss: 1.659\n", "Iter 17/50 - Loss: 1.533\n", "Iter 18/50 - Loss: 1.510\n", "Iter 19/50 - Loss: 1.514\n", - "Iter 20/50 - Loss: 1.504\n", + "Iter 20/50 - Loss: 1.503\n", "Iter 21/50 - Loss: 1.499\n", "Iter 22/50 - Loss: 1.500\n", "Iter 23/50 - Loss: 1.499\n", @@ -175,10 +175,10 @@ "Iter 43/50 - Loss: 1.011\n", "Iter 44/50 - Loss: 0.991\n", "Iter 45/50 - Loss: 0.974\n", - "Iter 46/50 - Loss: 0.959\n", - "Iter 47/50 - Loss: 0.945\n", - "Iter 48/50 - Loss: 0.932\n", - "Iter 49/50 - Loss: 0.919\n", + "Iter 46/50 - Loss: 0.958\n", + "Iter 47/50 - Loss: 0.944\n", + "Iter 48/50 - Loss: 0.931\n", + "Iter 49/50 - Loss: 0.918\n", "Iter 50/50 - Loss: 0.906\n" ] } @@ -241,7 +241,7 @@ "outputs": [ { "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQsAAADDCAYAAACVmTQ/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAbPklEQVR4nO3dfXxU1bno8d/khSQEQ6AkvIgIYrGIwjn2OR+QgwjK4VCgKiiWT6tePli9ChSpkBC9YMAASYH2HDge0Sj39mhrCx5qK6U9wMXegry68ChQa5UohEJiBEyjvCQhmfvHzJBJmMnssCczezbP95/M7Nmz9pM1M8/ee+211/J4vV6UUiqSpHgHoJRKDJoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWpNgtQET6A0uAd4HewCljzLMt1kkHVgLHga8DJcaYj+xuWykVO9E4sugK/NIYs8IY8wQwVUS+2WKdOUC5MaYY+BdgbRS2q5SKIdvJwhjzjjHmNy3KPNNitQnAbv/6B4EhIpJld9tKqdixfRoSTEQmAZuNMR+2eCkX+DLoeY1/WU24sgoKCrRrqVJxUlJS4mm5LGrJQkRGA6PxnXK0VAVcFfQ8y7+sVYsXL4643aqqKnJzcy1GGXtOjw+cH6PT4wN3xVhYWBhyeVSShYhMAG4DngB6isi1wF+AC8aYGmATcCuwQ0RuBt73L1dKJYhoXA35JrAOMMAfgEzg34FJwGmgBFgFrBSRBcD1wMN2t6uUii3bycIYsx/oFGGdc8BMu9tS7nfhwgUqKyupra0lcEd0Y2MjNTXOPhBNxBg9Hg9paWn06NGDlJTIqSCqDZxK2VVZWUlmZiZXX301Ho+vja2+vp7U1NQ4R9a6RIzR6/VSXV1NZWUlvXv3jvh+7cGpHKW2tpbs7OyLiUK1H4/HQ3Z2NrW1tZbW12ShHMXr9WqiiCGPx4PVAbA0WaiEVVFRwZgxY6isrLzsMvbv309eXh5FRUU8+eSTlJWVAbB8+XK+//3vRytUy86cOcPUqVN55ZVXmi3/wx/+wIABA8jLy6OwsJCnn36as2fPhi3nzTffpLq6OqqxabJQCau4uJhdu3axbNmyy3p/TU0NS5YsoaSkhIULF1JUVMTMmTNpaGjg/vvvj3K01mRmZjJ+/PhLlo8ePZprr72WmTNnsnjxYkaMGEF+fn7YctojWWgDp0o42dnZnD9//uLz0tJSSktLSU9Pb9MP5Pe//z3Dhg0jOTkZ8P1Q+/Xrx759++jZsyfHjh2jpKSE/fv3s3DhQvr378/s2bMZOHAgZWVlrFmzhl//+tfs3r2b+vp67rnnHtLT05k2bRrf+c532L17Nw899BDPPvssGzZsoLy8nNWrV7NhwwYWLFhAr169KC8vZ+nSpTQ2NjJjxgxuuOEGjh07xm233dZq7OPHj2f+/Pk0NjayfPly6uvrKS8vZ968eaSkpHDgwAGee+45vvWtb9HQ0MDvfvc7kpKSGDFiBJMnT76setdkoRLOn//8ZwoKCnjzzTc5d+4cGRkZ3H333ZSUlLSpnL/+9a987Wtfa7asW7duHD9+nJ49ewJQUFDA4cOHyc/Pp6ioiC+//JInnniC9957j4aGBpYsWcI777zDmTNnuOuuu3jrrbcYPnw4t99+O3PmzOGrr76iV69e9O3blxMnTrBo0SK2bt1KZmYm8+bN47XXXuPVV1+ltraWO++8k+nTp1NUVGQp/i5dunDy5EmGDRvGqFGjOHDgAC+//DIrVqxg8ODBzJo1i759+7J3715KSkpISkri7rvv1mShrhw9e/YkKyuL2tpa0tPTqa2tJSsrix49erSpnN69e3PkyJFmyz7//HOuueaai68D9OnTh7KyMgYNGsSwYcMYPXo048ePp0+fPtTU1LBy5UoaGhro3r37xXL69u1L586d6dy5M3fddRcbN27k8OHDLFiwgJUrV/Lxxx+zYsUKqqurufrqqykrK+Oee+4BuLj9SKqrq+nWrRs1NTU888wzNDQ08MUXX1yyXmZmJs8++yxZWVmcPn26TXUUTNssVEKqqqrikUceYfv27TzyyCN89tlnbS5j3Lhx7N69m4aGBsDXuHjs2DFEBIDjx48DUF5eznXXXcenn37K5MmT2b59O9u2bcPj8dC9e3fy8vKYO3cu3/3ud0Nu5/777+eVV14hIyMDgBtuuIHBgweTl5fHnDlzuOWWW+jfvz9Hjx4FfEc8kWzevJlRo0aRlJREXl4eixYtYvr06RdfT05Oxuv18uGHH1JYWMjUqVOZO3cuHTt2bHM9BeiRhUpI69atu/h41apVl1VG586dWbhwIQUFBWRlZfG3v/2N1atXk5yczPr16zl37hxFRUW8//77FBYWUldXR3FxMf369WPw4MHk5uaSn59PQUEBycnJDB48mLKyMg4dOkRpaenFhtdevXrR2NjIhAkTAJgwYQI7d+6kuLiYiooK5s+fzze+8Q0ef/xxjh07xscff8yhQ4e49957yczMBGD79u2Ul5fzwgsvkJmZyblz5y6edo0bN47Zs2eTnZ3NoUOHKCsrY/jw4axYsYKBAwcyadIkFi9ezLBhwzh+/DhvvfUWd9xxR5vry+PUSYYKCgq8etdpbDgpxsOHD3P99dc3W5aIvSOdKFyMLeu8sLAw5C3qehqilLJEk4VSyhJNFkopSzRZKKUs0WShlLJEk4W6Ym3ZsoUBAwbw4osvNls+duxYpkyZYusGNTfSZKGuWGPHjmXs2LG8/PLLF2/T/uCDDzh9+jTf/va329wj1O20U5ZypPT0tKBnaWHXs+L8+fCDu3Tv3p2BAweybds2xowZw8aNG5k4cSIAO3fu5PXXX6dTp04MGTKEKVOmMG/ePHJycvjoo48oLi7myJEjTJs2jQceeIADBw4wdOhQfvjDH9qK16miNbp3D3xTGA4xxvxDiNdHAf8KBG4J3GSMWRGNbStl14wZM1i+fDkiQlZWFnV1dQDk5+ezefNmOnXqxMiRI7nvvvuYOHEio0aNYuPGjaxfv55Zs2YxfPhwhg4dylNPPcXo0aM1WUQwAvgN8HetrDPHGPP/orQ95XLBRwPt3Tty2LBhnDx5kpKSEvLy8nj++ecB3z0aa9asAXw3ldXU1FBWVsbbb7/NiRMnLt6ZCr6bzTwej+N7cdoRlWRhjPlP/9FDax4U3x06WcBLxphj0di2UtHw2GOPsWPHjma3rF933XXMnj2btLQ0fvvb33LhwgVKS0vZu3cv27ZtY9euXRfXvRKGAoxVm8UHQJEx5oiIDAK2isiNxpjG1t5UVRVx0rKojwYUbU6PD5wVY2NjI/X19c2WBe4Kjba3336bHTt2sG3bNqZMmcKUKVP46KOP2LFjBwcPHmTx4sXk5+eTnZ1NTk4OY8eOpW/fvsydO5czZ87wySefsGfPHg4ePMgbb7zBLbfcQnl5OVu3bmXUqFHtErMd4eqxsbHR0m8tJsnCGFMV9PhPIpINXAMcbe19Vm9ucspNUOE4PT5wTow1NTUhD+Xb4/B+9OjRjB
49utmyAQMGsGXLlovPR44c2ez14LtdA/bu3XvxVOkvf/lL1OOMplD1mJSUZOnzb7dLpyKSKSI5/scFItLV/7gr0AFo+wAESqm4idbVkNuBB/HNc7oA+DEwDbgZeAw4AqwSkQ+AG4EHjTHnQ5emlHKiaDVw/hH4Y4vF/x70+i+BX0ZjW8rdAvNYXAkNhk7QlrrWHpzKUdLS0qiurrY88Y26fIHpC9PSrHV60x6cylF69OhBZWUlp0+fbjYxclKSs/driRhj8MTIVmiyUI6SkpJyySS9Thr2L5wrIUZnp0KllGNoslBKWaLJQilliSYLpZQlmiyUUpZoslBKWaLJQilliSYLpZQlmiyUUpZoslBKWaLJQilliSYLpZQlmiyUUpZoslBKWaLJQilliSYLpZQlCZ0sKioq2mW264qKCsaMGaOzaPu1V31oPTfXnvURjd9KVJKFiPQQkZdF5J0wryeJSImILBCRl0RkWDS2W1xczL59+1i2bFk0imtW7q5du6JebqJqr/rQem6uPesjGr8VTzQGRhWR+4BaoNAYIyFenwqMNMbM8M8bsgcYaIwJO9VUQUGBd/HixSFfy87O5vz54cClE74ETz/XVqdOnQr7Wrhyr7/ey+bN9aSnh36fE4ZbO3MGxo1L5ZNPQo/iHG78yMupDyvaWm64+DIyoLS0njvuiP/gvq19zkuXJrNmTTLhfmrtVc+Xlr0OmAVAenp62JnoCgsLKSkpueTLEqu5TicAW/zrnhaR88Ag4EBr5YabUm3Hjh384Aeb2LOn2yWvtVLvFlxaXqRyT53ysGPHFwwZUh/ydSdMDbh/fwfeead7K2skh1ne9vqwpq3lhosP1q+v5aab4l/HrX3Or77ag5MnWxtuv73quWXZnUhPT2fcuHEsWLDA0pSFwWI1YG8u8GXQ8xr/stbfFCZT5+bmcuONa9m7N5fU1FTq6+t58MEHWLas2HagTz31FD/72c/o0KEDdXV1rZY7eXIq+/YlkZHRldzc8Hu3eB9ZpKX5vqi33trI669fmtROnjxJt26hv7BtqY+2aEu5oeL7+c+TmT8/BY+nI7m5HWzHEw3hPue6Ot/PbN++Onr1Cv09aa96Di47NbWRuro6cnNzuemmm9pcTqySRRVwVdDzLP+yy3bqVAWPPjqZSZMm8cYbb1BZ+Qlhvu9tUlPzCY8+OpmHH36YtWvXtlruVf7/6Nw5+9ttT+fP+5JF587ekP9LY2Nj2P+xLfXRFm0pN1R8OTm+H53T6x6aYrz6ai/hziraq56Dy276rVxeI2e7JQsRyQQ6GmM+BzYBI4FX/W0W6cCf7JQfmKC2qqoqqjNWB098u2rVqlbXzcgIfGE9QPzPm8M5e9b3NyOj7e9tS33EstzA/5IIycJK/bdXPQeXbfe3Equ5TtcDfy8ihUAf4KHWGjcTRaJ8YQPxXU6ycKqOHX1/fYnauRobobbWF2O4RvBEEau5ThuB+dHYlpMEvrCBPYdTBZJFIF43CBzVJUrdZ2R4cfiEZRElePjxlThHFr49W+AH5gaJU/e+v244qtNkYUPHjsFtFs5lp83CqZpOQ+IbRyRuqntNFjYEzkET51A4vnFEU3p64DTE2Yk6cCXKDUd1mixsCOzdzp+PbxyRuLHNIlHqPrAjcUPda7KwIbCndvreLRCfG/ZuAU11H984ItHTEAUE97OIcyARuPE0JFHaLNxU95osbEiUvZsbT0NSUyEpyUt9vYcLF+IdTXiBxu9AY3gi02RhQ6KcN7tp7xbg8STG0YWb6l6ThQ1NHYO0zSIeEuHITtssFJAYX1ZoOvJxwxc2WCLUv5vqXpOFDYlyGuKmy3fBAu0Agb4MThQ4qtM2iytcIuzZwF3nzcESof4DdZ/oN5GBJgtbmt+i7lxuvDcEEuP+EDcd1WmysCERvqzgri9ssESofzddttZkYUMiXLoD956GJMKNfIHYAveyJDJNFjYEnzNHYZD0dtHQ4J7BV1pKhBv53HRUp8nChtRUSEnx0tDgoT704N5x13TpzovHuTvgy5IIgw/ppVN1kdPPm920Z2spUPdOvnTddOk0zoFEgSYLm5y+d3PTpbuWAm0WTu5B21T/Dj1PbYNoDdg7BpiMb3h/rzFmcYvXp+EbuDewD1hrjHk1GtuON6cfWbjpRqaWnF734K6rIbaThYh0BF4ABhljakVkg4jcaYzZ1mLVqcaYI3a35zS+vgsex04H4KYva0uaLGIrGkcWtwJHjTG1/uc78U1X2DJZzBKRSqAj8Jwx5nSkgq1Mrxbv6QFTU7sDHThx4gtycuoueT3e8R0/3gHoTkpKfdj6jHeMkYSLr6GhE9CF06fPUVUV3/8hXIxfftkTSOLs2VNUVcX3Xnq7n3M0koWVqQn/CGwyxnwuIuOB14E7IxZscdq/eE4PeNVVvipMS+sSdgrDeMYXmLowKyu11TjiPcViJKHiy8nxNbl5vc6YwjBUjHV1vnlae/fuihOq2M7nHI1kEXFqQmPMp0FP3wLeFJFkN0w01HyyG+edhrhpwNiWnN64DO7qEBeNqyG7gWtFJM3//B+BTSLSVUSyAESkWEQCienrwKduSBTg/KH19NJpfLmp/m0nC2PMWeBxYLWILAEO+Bs3C4AZ/tUqgTUi8jTwNL6pDl3B6Xc+6qXT+LlwAerrPXg8XjrE/yzJtmhNX7gV2NpiWX7Q4+jO9OogTm+Rd1NrfEuBBJgIde+G3rPaKcsmp9/M5OZ+Fk6/kc9tiVqThU1OP7IInB658TRE6z62NFnYlChtFm7ZuwVz+oDJgStRbjmq02RhU6IcCrvh0l1LTq97N43sDZosbHP60HpuHVIPnH8a4rZErcnCJqefhrht7xbM6YMPua3uNVnY5PSOQW5uswgMPtTY6MzBh9x2JUqThU1O73Lc9IWNcyDtxMn1r6chqpnAoCZObbNounznjr1bS05ut9DTENWMk/ds0HR65NYjCye3Gbmt7jVZ2OTkPRu4b+/WkpOvRrltQmpNFjY5/Vq/my+dgrPrX9ssVDNO70Xoti9sS04+DXFb3WuysEkvncZXU/07L1m7re41Wdjk9AZONw2+EkrTmBZxDiQEbbNQzQQ3cDqtF2Fg8JWkJC+pqfGOpn04eUwLtzUua7KwKTkZOnTw4vV6qK2NvH4suW3wlVCcfGSnl07VJZzayOa2PVsoeuk0djRZREFbLt9VVFQwZswYKisroxpDqHLd1hofSri6b696bkvZbqv/qCQLERkjIs+LyCIRKQzxerqIPCciT4nI/xaRAdHYrlM0nTdH3rsVFxeza9culi1bFtUYQpXr9j4W0FT3LY/q2que21K225JFrKYvnAOUG2OWi8jNwFrgNrvbdgpfi7yHpUuT6dat+Q/z7NlsOnZM5vnn19DQcAH4BrCS0lIoLX2e5OQUZsx4/LK33Vq59903wx/fZRfveIH/bcuWJM6cab0+7NQztF72tGnT6Ngxudn6R4+66ya+WE1fOAHfFAAYYw6KyBARyTLG1LRWcCJMXwjQpUsOkM4vfpEc4tXA/Es/CPnehgb4t3+zs
/Xw5a5b53ucnV1LVdXJsCU4oQ5b01p8GRkdga/x7rtJvPtuEu1Xz7Ra9tq1od/h8Xjxej+nqqrR7sZtS5TpC8Ot02qySITpCwFeegk2bbpAY4jvw1dffUmnTr6E8atf/Yq9e/eQnJxCQ8MFhg4dxuTJk21vv7Vyk5Jg4sSkiHUU7zqMJFx806dDZmY91dVNp4DtVc+tlR38OQcbOLCRQYO6RWXb0eD46QstrpOw+vWDWbNCT7BWVfUVubm+49CdO1/j0Ud78PDDD7N27VoqK19j9uy7bW+/vcpNBGlp8NBDzbN0e9ZHuLKDP2e3ikayuDh9of9U5B+B50WkK3DBf6qxCd/pyg5/m8X7kU5B3Ghd4LwAWLUqevMutVe5iao96+NKrutYTV+4Cl9CWQDMBR62u12lVGzFavrCc8DMaGxLKRUf2ilLKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSlmiyUEpZoslCKWWJJgullCWaLJRSltgag9M/gncJ8AnwdeBpY8xnIdY7AhzxPz1ujPmene0qpWLP7oC9y4D/a4xZLyLfBlYCD4ZY76fGmEU2t6WUiiO7yWICsNT/eCfwH2HWGyki+fgmGvq9MWaXze0qpWIsYrIQkc1A9xAvPUPzaQlrgC4ikmKMudBi3QJjzD7/JMrvishEY8zhSNtOlLlOW+P0+MD5MTo9PrgyYoyYLIwx/xzuNREJTEtYjW9Kwi9CJAqMMfv8f8+KyHv4Zi2LmCwSZa7TSJweHzg/RqfHB+6P0e7VkMC0hOBLAJsARCRJRPr4H98pIuOC3nM9UGZzu0qpGLPbZvE08CMRGQD0B+b5lw8GXgVuxjcB8iIRuQXoBWwwxrxtc7tKqRizlSyMMaeBR0Isfw9fosAYcxC41852lFLxp52ylFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYkmC6WUJZoslFKWaLJQSlmiyUIpZYnd6QuT8I3BWQTcYYw5FGa9McBkfIP3eo0xi+1sVykVe3aPLIYAe4Gz4VbwTyz0AvBD/xSGg0XkTpvbVUrFmK1kYYz5b/9I3q25FThqjKn1P9+Jb9pDpVQCsTV9oTHmTQvbCJ7iEHzTHFqaFqmwsNDKakqpGLA1faFFgSkOA7L8y1pVUlLisbldpVQUtdvVEBHp53+4G7hWRNL8zy9Oc6iUShwer9d72W8WkS7ATGAuvukKXzPG7BGRHOA9oL8x5ryI/BNwH/A5UK9XQ5RKPLaShVLqyqGdspRSlmiyUEpZYqsHZ6xE6gEqIunASuA48HWgxBjzkcNinA/0ACqBb+K79Pyhk2IMWu97wM+Aq4wxXzklPhHxAD/wP+0LZBtjpscqPosx9sP3XXwH+Dt87XhWuhhEK74ewBJgiDHmH0K8ngQsA74CrgXWGmP2WCnb8UcWFnuAzgHKjTHFwL8Aax0YYyfgSWPMj4ANwAoHxoiIDARujGVs/u1aie8BoNoYs9oY8yTwrw6MMR942xhTAvwI+HEsYwRGAL8BwnU9uB/IMsYsAeYDr4hIspWCHZ8ssNYDdAK+S7QYYw4CQ0QkK3YhRo7RGLPQGBNoTU7Cl9ljKWKM/h9DPhCPq1VWPufvAV1FZLaIBPaOsWQlxs+AHP/jHGB/jGIDwBjznzTvBNlS8G/lNHAeGGSl7ERIFlZ6gF52L9Eosbx9EekA/A9gQQziCmYlxqVAkTGmLmZRNbES37X49oqrgZ8C/2V1rxglVmL8CTBURH4CPAP8nxjFZtVl/1YSIVlY6QF6Wb1Eo8jS9v2JYg3wv4wxZTGKLaDVGEXkGqALcL+IFPgXPyki4oT4/Grw3biIv00qC7gmJtH5WInxp8DL/tOkScA6Eekam/AsuezfSiIki5A9QEWka9CpxiZ8h4iIyM3A+8aYGifFKCIZwIvAT4wx+0Xk3hjGFzFGY8wxY8w0Y0yJ/3wbf6zGCfH5l20DrgPwL0vG12AcK1ZivAao8D/+Amgkzr8zEcn0d5SE5r+VrkA68Ccr5SREp6xQPUBFZDlw2hhT4v8hrsT3IV0PLIvD1ZBIMf4KuAk44X9LZqjW6njG6F8nB/if+MYoKQJeNMYcd0J8ItIZWA4cBfoDG4wxv4tFbG2IcQS+Bvd3gX7AfmPMCzGM73bgIWAcvqPYHwPTgZuNMY/5r4YU4xtWog/wktWrIQmRLJRS8ZcIpyFKKQfQZKGUskSThVLKEk0WSilLNFkopSzRZKGUskSThVLKkv8P9WWPNod/4y4AAAAASUVORK5CYII=\n", + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQUAAADBCAYAAADGmKWHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUMUlEQVR4nO3dX2gbV74H8K/s+G+7kew0brqXLmu1W1oI1H8mC2Vfiq2+ZBv2Ytyk5KHsrf8Eml32waR116YmEOLkgh9aWi7R2tuHgtl2jaFstlBu7IWUpnR3EnWhS7ttI+WyC3EU6siBJLZkS/dBM4qOPSOPNJLmzPj7gRB7JB3/fCT/5syZM/PzZTIZEBHpapwOgIjkwqRARAImBSISMCkQkYBJgYgETApEJNhV6gsVRQkACAJQACyrqjqXtz0EIAEAqqpesBskEVWPnZHCYQAJVVXDAM7mbR8GcEVLBsfsBEdE1VfySEFLBvrI4EreQwcAhLWvA4XaGB0d5copIgedOXPGt3lbyUkhz+sAhkp98cmTJ7d9TjweR1tbW6k/oioYo32yxwfIH2Mx8U1MTBhutzXRqChKP4BJAK15m/+W933CTvtEVH12JhpDyI4SjmibXtCSRBjAYUVRogDO2Q+RiKrJzpzCBQDdm7bNaV+Gt76CiNygHHMKRGVx9+5dXL16FTJfuZtOp3H79m2nwzBlFJ/P54Pf78dDDz1kqQ0mBZJGMpnEY489htraWqdDMZVKpVBXV+d0GKaM4tvY2MC1a9csJwWuaCSpyJwQ3Kq2trao0ReTArnO9evXEQqFsLS0VHIbkUgE8/PzWFhYwPT0NKLRKABgfn4eY2Nj5QrVVCKRwMGDBw3jeuqpp7CwsICFhQVMTU0hkUhUPJ58TArkOpOTk7h06RJOnz5d0usTiQSmp6fR19eH3t5eDA4OYnx8HADQ09NTzlBNBQIBtLe3b9ne2dmJ9vZ29Pb2ore3FyMjIzh69KhhG4lEAlNTU2WPjXMK5BqBQACrq6u578PhMMLhMBobG4vam87NzaGzs1PY1tLSgkgkgvb2dkQiEUQiESwuLmJgYACXL19Ga2srFhcXcejQIVy8eBGtra3o6OjAF198gbm5OXR3d6O9vR1zc3OYnZ3F8ePHMTIyAgBYXFzMPb+1tTX3/FgsZvn3TiQSWF5exuLiIlZWVnJxXb58GZFIBH6/H4uLi1heXsbw8DACgYDl/tiMIwVyja+++gpHjhxBU1MTAKCpqQkvvvgivv7667L+nM7OTnR2dqKnpwczMzNYXFzE3NwcOjo68MYbb6C9vR3t7e2YmZlBT08PWlpaMDIygr6+vlwbfX19CAaDGB8fF54/NjaGnp6e3IjAquXlZQSDQQwODqKjowMzMzO5RNTZ2bnlMTs4UiDXeOSRR7B7926sra2hsbERa2tr2L17N/bt21dUO/39/XjllVcwODiY2xaLxdDZ2Wk44hgYGAAAjI+PI5lMoru7G4FAIPf8/L1yT08Ppqamcq8BIDz/+PHjaG3NLvi9deuWpXgTiQSCwSAWFhYQi8XQ3S0sD0I0GkUsFkMsFsPTTz9ttRtMMSmQq8TjcQwNDWFgYAAzMzMlTTYGAgGcOHEC09PTaG9vRywWw9tvvy08Rz98GBkZwdTUFDo6OtDf34/9+/fnhv9Adg8ei8VyyaG/vx9jY2O5RHHq1Cnh+SMjI8LhQyQSEQ5l9D/whYUFABBiW1lZQWtrK2KxGKLRKJaXl5FIJBCLxXKPXbt2DdFoFNFoFMFgsOi+AQCfkwtFRkdHM7wgqnpkj/Gbb77BE0884XQYBblxnQIAfPfdd3j88ceFbRMTE4ZXSXJOgYgETApEJGBSICIBkwIRCZgUiEjApEA7TjQaxTPPPINIJFJw207FdQokncbGhrK1tbq6tmVbMBjMrVN45513AGTXAOirA3c6jhRoR/L7/aaPRaNRTE9PY35+PrcQaHp6OnfF4sLCAg4ePIhIJFKVKyqrjUmBpLO6ula2f4X09fVhfn5+y6rCzdcr5F9X8O6776K3t1dY5uw1PHygHau3txdHjx7FiRMntjyWf72C0TUHLS0t1Qy1qmwlBe3uzQdUVX0tb1sAwGUAFwCcU1X1isnLiRwRjUZzVxl2d3fD7/cjEonkrinYfL1C/jUHsVgM58+fz13vYHT9gtvZSgqqqs4pivKcwUPPqaoatdM2UaUEg0HMzs4CQO6eBwDw2WefCc/R5f/BHzp0CHV1dXj++ecBAB999FGlw626Sh0+dCmK0oVsrUkWmCVykbInBVVVEwD0CtR/RPYwwlQ8Ht+2TTdM5jBG+9LpNFKplNNhFLSxseF0CAWZxZdOpy39rQEVSAqKogwD+EBLDtte0G31Ul6ZL/nVMUZ7EokEampqpL+js8yXTgNb49vY2EBtba3l974cE41BRVG6VFW9on3/gbYtCOC1wi0Q3VdfX49r165JXwympkbeM/lG8enFYKyyPdEI7VAh73sgW5qeZx2oKM3NzVKPZAD5b1RTjvjkTXlE5AgmBSISMCkQkYBJgYgETApEJGBSICIBkwIRCZgUiEjApEBEAiYFIhIwKRCRgEmBiARMCkQkYFIgIgGTAhEJmBSISMCkQEQCJgUiEjApEJGASYGIBLaSgqIo/YqinN20LaBtDymKErIXHnD9+nW88MILWFpastuUYduhUKgibbtRpfqD/SyqZD+X42+lEmXjhgHMqaoatVIMZjuvvvouPv98L44d+xOOHTtmp6ktzp07j08/bbHUdldXGvv2lfXHl106DVy65MPt2z7DxxOJRgQC5vuBYvqjGFbbNYtv374Murrkve277upV4J//3H4/W8l+/vzzvfjtb/8Hv//9yZLbqUTZuAMAwtrXgVIbCQQCWF1dBTAJ4E/4+GPg44/LEJ3gVwB+ZantJ59M44sv5K5e9Ic/1ODllwsVKtm7TQvW+6M4Vts1j+/TT5Po7pY3Mdy5A/z0p/W4c8c4IYsq28+zs/+F2dlGNDY2llQVzPFS9GalrD755BOcOnUKf/7zVayvn0dNTQ327m3DT37yOOrrG2z9zGRyDd9++y1u3ryZK55h1nYmA/zlL02IxQqXuJOhJNs//rEbgB8//nEKweD6lsfX11PYtWtr0iimP4pRbLtG8X35ZT3i8Vr8/e+38eij90qOpVzM3ud//7sWd+78EA0NafzsZ2uGz6lWP9fVxfHzn/8nxsfHLZeKy1eJpPA3AK0AEtq/gswKV7S1taGtrQ3p9AwaGt5DKpXCL34xiLfeeqsMITbh17/+HWZmZtDYWI9kMmnadjoNNDcDa2s1eOihNhQqDuR0kZCammy5tV/+0ofR0a2BxuMrJjFa74/iFNeuUXwvvwzMzgL19X60tf3AZjzlYdSHt25lRwg/+hHw0UdmH5LK93NDQx1SqRTa2gaxf//+klqzPdEIrWxc3vdhAPok4zk77cfjcQwNDeHDDz/E0NAQbty4Yac5w7YvXrxYsO2aGqCxMTtsXV0t24+viHvajrSpqfjXWu2Parer/y73nB8kFHT3bv
b/7fq+0v1cjr+VSpWNCxu/ojjvv/8+gOwv/Oyzz5ajyS1tA8Cbb75Z8LnNzdmEcO9e9mtZ3buX3Vs1Nxd/7F1Mf1SzXf130X83WelJa7vPR6X7uRx/K1ynYIGe/fW9gays7q3chH1ffUwKFjQ1uWNvpR/eeOGDqXNLUlhdzX429M+KmzEpWKAPCd1yXCvzIU6x9N9F9vkcL/U9k4IFbtlb3b2b3VvpE6NeoO959d9NVvpno7HR2TjKgUnBArfMgOt7Uy/srXTs++pjUrDALXMKXprs0rklKegjGc4p7BBumVO4f0rS4UDKyD19n/3fCwmZScEC/ThR9jmF+x9M9++tdPr8iOxzClbXKbgBk4IF3Fs5h31ffUwKFrhtTsELeysd5xSqj0nBAjecktzYAJJJH3y+DBrsXUQqFT3Bydz3AEcKO44bhrD5H0qf3AOaouh7Xn3FoKw4p7DDuGEI68XTkYA7RmlA/uGDw4GUAZOCBW6YU/DS8DWfG0ZpgLfO/DApWOCGvZWdy6Zlpp8OXl31IZ12NpZCePiww7hhb+XVkYLPlz9ScziYArzU/0wKFnBOwVnu6H+ektxR3HClnpf2VJu54fDNS/eyYFKwQH+jZb6m36tzCoA7Tkt6aeEYk4IFblhA4+WRAvu/upgULHDDzUO9PKcg+0Tv+jqQSvlQU5NBfb3T0dhX8t2cFUUJAAhBq+2gquqFvO2XkS0Xd05V1St2g3SaflpM1g8l4K2bfGwm+1Wq+acjvbCa1M4t3gvVjHxOVdWovdDk4Ybhq5dmvzcTR2ry/X5eG6XZSQqFakZ2aQViEvoIwoyVslZOl2TL7oUfxb175vE6HePNm9mScZnMHcTjtw2f43SM2zGLz+fbA6AZS0sriMedHa4Zxfivf9UC+CEaGjZKKtNWTuV4j8teNk5V1QS0AjFWqk5bLbXmZEm2TAbw+TJIJn3Ys6cNtbXGz3MyxlotqD17HkBbm/ndQ50ubbcdo/haWrIf07o6OUrHbY7x+++zo7QHHqiRon/txmBnolGvGQnk1YxUFGVYm1cAgKCN9qWRXVWX/VrWeQUvnRLbTPZTwl7rezsjhTCAw4qiRKHVjNRqSX6AbH3JIIDX7Icoh+bm7Jt/9y7w4INOR7OVfmbES7d31+lzCrIuHtN3FF64vTtgIylohwnhTdv0WpJXtH+eIfsZCK/trfLJ3vdeuhgK4DoFy2Rfq+DlU5Kyn/3x2mpSJgWLZP9geukmH5vJvnjJS9WhACYFy9wyhPXiOgV9nkTWURoPH3Yo2UcKXlp7vxn7vrqYFCyS/Uo9L080yn9KknMKO5L8eysvn5LM/s9TktXBpGCRW+YUvDhSuD+n4HAgJrzW90wKFsk+A+61D2Y++fuehw87kuy3effalXr53LLEnIcPO4zM9wlMpYD1dR9qazOoq3M6mvJzyzJnr4zSmBQsknlv5dWScTq3zOd4ZZTGpGCRzMe1XttTbSZz3wOcU9ixZL7Nu9eOaTeT/XSw1/qfScEimQ8f9AVVXtlTbVZfn73JTSrlw/q609Fs5bWL0ZgULJJ5COvl1YxAdp6E/V89TAoW3R8p8PDBCTKP1PRDSq+sJmVSsOj+nILDgRjw2vDViMynhL3W/0wKFrlhT+XFy6Z1Mi8e89rCMSYFi2Q+pvXaeXIjsvZ/JuO9/mdSsEg/XrR6SvL69esIhUJYWloqaxxG7XptT2XErEpUNfvZSCoFbGz4sGuXd1aTlpwUFEUJKIrSryhKSFGU0Hbb3U7fU1m9pn9ychKXLl3C6dOnyxqHUbteO6Y1YjZSqGY/G/HiwrFKlI0rVE7OtfQ3/cYNHx5+eGsV0UzmP+Dz+bCysqJtOQvgLMJhIBzOfnL8fn/JP79Quw0Nfi1G784p6L/bkSN1qKsr3B+V6me/3597n3XptB5fyT9SOpUoG1eonNwWbigbB2Tf/P37H8aXX9ZjZcXoEELfFjBtI/d5K4l5u2trwK5dGTz5ZKJgWTUZ+rGQQvF1dj6I8+db8g7fAqbPrVQ/Z9s1Pnw8cOAe4vHv7fzgspCybFyx3FA2TvfXv2Zw+/aa4WM3b97E3r17AQCvvvoq3nvvPdTX1yOZTOKll17C2bNnbf/8Qu3W1wPNzT8AULismgz9WIhZfGNjwG9+syasaHSin/Pf53x+fy18Pjn61u57bCcp6GXjEsgrG1dgu+vV1ACBgPFjyWQm99jKyv9hePgwBgYGMDMzg6Wla6avK0al2nWLzZW5nOjn/PfZq3yZTGnHoVq9yMMAogCgquoFrWzchc3bzdoYHR3NnDx5ctufFY/Hpd/DMUb7ZI8PkD/GYuKbmJjAmTNnthwPVapsXHjLC4jIFbhOgYgETApEJGBSICIBkwIRCZgUiEjApEBEAiYFIhIwKRCRgEmBiARMCkQkYFIgIgGTAhEJmBSISMCkQEQCJgUiEjApEJGASYGIBEwKRCRgUiAiAZMCEQlKvnGrdjfnELTbuOt3bda2X0b2rs7nVFW9YjdIIqqeSpSNA4DnVFWN2guNiJxQibJxANClKEoXgEShug+Ae8rGbYcx2id7fID8MValbJxW4EWQV99hC60exJz22m0LzLqpbNx2GKN9sscHyB9jxcvGFUgAhuXhFEUZBvCBlhyCtqIjoqqzc/gQBnBYUZQogHNAblTxAYCgoihBAK/ZD5GIqqlSZeOuaP+IyGW4ToGIBEwKRCRgUiAiAZMCEQmYFIhIwKRARAImBSISMCkQkYBJgYgETApEJGBSICIBkwIRCZgUiEjApEBEAiYFIhIwKRCRgEmBiARMCkQkYFIgIgGTAhEJ7NzNWb978wFVVV/L2xaAQTk5InIHWyMF7e7NgU2bhwFc0ZLBMTvtE1H12RopmChUTm6LiYmJCoRARKUqe9m4Ypw5c8ZXjnaIqHzslI0zY1hOjojcwdacgjaKCGoVpvXvwwBCiqKEoJWTIyL38GUyGadjICKJcJ0CEQmYFIhIUIlTkraYLX6SZVHUNvEFASgAlst1hqYU2/WVNvcTVVXVscrghWJUFGUYgApAUVU1bPR6h+MLAVgGEHTyfdZiKfsCQhlHCmaLn2RZFGUWx2EACe1DfNaRyO4z7SvtA3MA2TNETjKMUf+D0xKWk6thC8UHLT6n+7AiCwhlTAoHkM3CgPjLmm2vNsM4VFUNq6oa1f7oHNsDawr1lYLsaWOnmcX4ArJntEIAuqodVB6z9/kCgLOKovwRziatQmz9rciYFNzudQBDTgdhRFGUkEuuRbmgxfm604Fspp1+n0Q2sTo9IqwIGZOCvvgJEBc/mW2vNtM4tOO7STg/rDSLcVnbAz8HZ/fCgHmMl+F8/wHm8YVUVZ1TVfW/AUSrHpU1tv5WZEwKWxY/SbYoyjA+7fvXAfwOzu9BDGPUjoNV7Tl7nApOYxZjGECXtn1StvgAXNDe7y4A/+tgfMiLqawLCLl4iYgEMo4UiMhBTApEJGBSICIBkwIRCZgUiEjApEBEAiYFIhL8P4IEZQA8UOJ3AAAAAElFTkSuQmCC\n", "text/plain": [ 
"
" ] @@ -274,6 +274,74 @@ " ax.legend(['Observed Data', 'Mean'])" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Notes on other Non-Gaussian Likeihoods\n", + "\n", + "The Bernoulli likelihood is special in that we can compute the analytic (approximate) posterior predictive in closed form. That is: $q(\\mathbf y) = E_{q(\\mathbf f)}[ p(y \\mid \\mathbf f) ]$ is a Bernoulli distribution when $q(\\mathbf f)$ is a multivariate Gaussian.\n", + "\n", + "Most other non-Gaussian likelihoods do not admit an analytic (approximate) posterior predictive. To that end, calling `likelihood(model)` will generally return Monte Carlo samples from the posterior predictive." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of output: Bernoulli\n", + "Shape of output: torch.Size([101])\n" + ] + } + ], + "source": [ + "# Analytic marginal\n", + "likelihood = gpytorch.likelihoods.BernoulliLikelihood()\n", + "observed_pred = likelihood(model(test_x))\n", + "print(\n", + " f\"Type of output: {observed_pred.__class__.__name__}\\n\"\n", + " f\"Shape of output: {observed_pred.batch_shape + observed_pred.event_shape}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of output: Beta\n", + "Shape of output: torch.Size([15, 101])\n" + ] + } + ], + "source": [ + "# Monte Carlo marginal\n", + "likelihood = gpytorch.likelihoods.BetaLikelihood()\n", + "with gpytorch.settings.num_likelihood_samples(15):\n", + " observed_pred = likelihood(model(test_x))\n", + "print(\n", + " f\"Type of output: {observed_pred.__class__.__name__}\\n\"\n", + " f\"Shape of output: {observed_pred.batch_shape + observed_pred.event_shape}\"\n", + ")\n", + "# There are 15 MC samples for each test datapoint" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See [the Likelihood documentation](http://gpytorch.readthedocs.io/en/stable/likelihoods.html#likelihood) for more details." + ] + }, { "cell_type": "code", "execution_count": null, @@ -285,7 +353,7 @@ "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -299,7 +367,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.8.0" } }, "nbformat": 4, diff --git a/gpytorch/likelihoods/bernoulli_likelihood.py b/gpytorch/likelihoods/bernoulli_likelihood.py index 43942f061..614dc6546 100644 --- a/gpytorch/likelihoods/bernoulli_likelihood.py +++ b/gpytorch/likelihoods/bernoulli_likelihood.py @@ -1,10 +1,13 @@ #!/usr/bin/env python3 import warnings +from typing import Any import torch +from torch import Tensor +from torch.distributions import Bernoulli -from ..distributions import base_distributions +from ..distributions import base_distributions, MultivariateNormal from ..functions import log_normal_cdf from .likelihood import _OneDimensionalLikelihood @@ -21,26 +24,41 @@ class BernoulliLikelihood(_OneDimensionalLikelihood): p(Y=y|f)=\Phi((2y - 1)f) \end{equation*} + .. note:: + BernoulliLikelihood has an analytic marginal distribution. + .. note:: The labels should take values in {0, 1}. 
""" - def forward(self, function_samples, **kwargs): + has_analytic_marginal: bool = True + + def __init__(self) -> None: + return super().__init__() + + def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Bernoulli: output_probs = base_distributions.Normal(0, 1).cdf(function_samples) return base_distributions.Bernoulli(probs=output_probs) - def log_marginal(self, observations, function_dist, *args, **kwargs): + def log_marginal( + self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any + ) -> Tensor: marginal = self.marginal(function_dist, *args, **kwargs) return marginal.log_prob(observations) - def marginal(self, function_dist, **kwargs): + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> Bernoulli: + """ + :return: Analytic marginal :math:`p(\mathbf y)`. + """ mean = function_dist.mean var = function_dist.variance link = mean.div(torch.sqrt(1 + var)) output_probs = base_distributions.Normal(0, 1).cdf(link) return base_distributions.Bernoulli(probs=output_probs) - def expected_log_prob(self, observations, function_dist, *params, **kwargs): + def expected_log_prob( + self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any + ) -> Tensor: if torch.any(observations.eq(-1)): # Remove after 1.0 warnings.warn( diff --git a/gpytorch/likelihoods/beta_likelihood.py b/gpytorch/likelihoods/beta_likelihood.py index 0ec25f38b..ddeedfd25 100644 --- a/gpytorch/likelihoods/beta_likelihood.py +++ b/gpytorch/likelihoods/beta_likelihood.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 +from typing import Any, Optional + import torch +from torch import Tensor +from torch.distributions import Beta -from ..constraints import Positive +from ..constraints import Interval, Positive from ..distributions import base_distributions +from ..priors import Prior from .likelihood import _OneDimensionalLikelihood @@ -27,16 +32,18 @@ class BetaLikelihood(_OneDimensionalLikelihood): p(y \mid f) = \text{Beta} \left( \sigma(f) s , (1 - \sigma(f)) s\right) :param batch_shape: The batch shape of the learned noise parameter (default: []). - :type batch_shape: torch.Size, optional :param scale_prior: Prior for scale parameter :math:`s`. - :type scale_prior: ~gpytorch.priors.Prior, optional :param scale_constraint: Constraint for scale parameter :math:`s`. 
- :type scale_constraint: ~gpytorch.constraints.Interval, optional - :var torch.Tensor scale: :math:`s` parameter (scale) + :ivar torch.Tensor scale: :math:`s` parameter (scale) """ - def __init__(self, batch_shape=torch.Size([]), scale_prior=None, scale_constraint=None): + def __init__( + self, + batch_shape: torch.Size = torch.Size([]), + scale_prior: Optional[Prior] = None, + scale_constraint: Optional[Interval] = None, + ) -> None: super().__init__() if scale_constraint is None: @@ -49,19 +56,19 @@ def __init__(self, batch_shape=torch.Size([]), scale_prior=None, scale_constrain self.register_constraint("raw_scale", scale_constraint) @property - def scale(self): + def scale(self) -> Tensor: return self.raw_scale_constraint.transform(self.raw_scale) @scale.setter - def scale(self, value): + def scale(self, value: Tensor) -> None: self._set_scale(value) - def _set_scale(self, value): + def _set_scale(self, value: Tensor) -> None: if not torch.is_tensor(value): value = torch.as_tensor(value).to(self.raw_scale) self.initialize(raw_scale=self.raw_scale_constraint.inverse_transform(value)) - def forward(self, function_samples, **kwargs): + def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Beta: mixture = torch.sigmoid(function_samples) scale = self.scale alpha = mixture * scale + 1 diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index 3ce286946..5870a4e64 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -3,13 +3,16 @@ import math import warnings from copy import deepcopy -from typing import Any, Optional +from typing import Any, Optional, Tuple, Union import torch -from linear_operator.operators import ZeroLinearOperator +from linear_operator.operators import LinearOperator, ZeroLinearOperator from torch import Tensor +from torch.distributions import Distribution, Normal +from ..constraints import Interval from ..distributions import base_distributions, MultivariateNormal +from ..priors import Prior from ..utils.warnings import GPInputWarning from .likelihood import Likelihood from .noise_models import FixedGaussianNoise, HomoskedasticNoise, Noise @@ -18,8 +21,9 @@ class _GaussianLikelihoodBase(Likelihood): """Base class for Gaussian Likelihoods, supporting general heteroskedastic noise models.""" - def __init__(self, noise_covar: Noise, **kwargs: Any) -> None: + has_analytic_marginal = True + def __init__(self, noise_covar: Union[Noise, FixedGaussianNoise], **kwargs: Any) -> None: super().__init__() param_transform = kwargs.get("param_transform") if param_transform is not None: @@ -31,7 +35,7 @@ def __init__(self, noise_covar: Noise, **kwargs: Any) -> None: self.noise_covar = noise_covar - def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any): + def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any) -> Union[Tensor, LinearOperator]: return self.noise_covar(*params, shape=base_shape, **kwargs) def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor: @@ -42,13 +46,13 @@ def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: # Potentially reshape the noise to deal with the multitask case noise = noise.view(*noise.shape[:-1], *input.event_shape) - res = ((target - mean) ** 2 + variance) / noise + noise.log() + math.log(2 * math.pi) + res = ((target - mean).square() + variance) / noise + noise.log() + math.log(2 * math.pi) 
res = res.mul(-0.5) if num_event_dim > 1: # Do appropriate summation for multitask Gaussian likelihoods res = res.sum(list(range(-1, -num_event_dim, -1))) return res - def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal: + def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal: noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2) return base_distributions.Normal(function_samples, noise.sqrt()) @@ -86,17 +90,24 @@ class GaussianLikelihood(_GaussianLikelihoodBase): .. note:: This likelihood can be used for exact or approximate inference. + .. note:: + GaussianLikelihood has an analytic marginal distribution. + :param noise_prior: Prior for noise parameter :math:`\sigma^2`. - :type noise_prior: ~gpytorch.priors.Prior, optional :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`. - :type noise_constraint: ~gpytorch.constraints.Interval, optional :param batch_shape: The batch shape of the learned noise parameter (default: []). - :type batch_shape: torch.Size, optional + :param kwargs: - :var torch.Tensor noise: :math:`\sigma^2` parameter (noise) + :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise) """ - def __init__(self, noise_prior=None, noise_constraint=None, batch_shape=torch.Size(), **kwargs): + def __init__( + self, + noise_prior: Optional[Prior] = None, + noise_constraint: Optional[Interval] = None, + batch_shape: torch.Size = torch.Size(), + **kwargs: Any, + ) -> None: noise_covar = HomoskedasticNoise( noise_prior=noise_prior, noise_constraint=noise_constraint, batch_shape=batch_shape ) @@ -118,6 +129,12 @@ def raw_noise(self) -> Tensor: def raw_noise(self, value: Tensor) -> None: self.noise_covar.initialize(raw_noise=value) + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal: + """ + :return: Analytic marginal :math:`p(\mathbf y)`. + """ + return super().marginal(function_dist, *args, **kwargs) + class GaussianLikelihoodWithMissingObs(GaussianLikelihood): r""" @@ -141,28 +158,39 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood): :type batch_shape: torch.Size, optional :var torch.Tensor noise: :math:`\sigma^2` parameter (noise) + + .. note:: + GaussianLikelihoodWithMissingObs has an analytic marginal distribution. 
""" - MISSING_VALUE_FILL = -999.0 + MISSING_VALUE_FILL: float = -999.0 - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def _get_masked_obs(self, x): + def _get_masked_obs(self, x: Tensor) -> Tuple[Tensor, Tensor]: missing_idx = x.isnan() x_masked = x.masked_fill(missing_idx, self.MISSING_VALUE_FILL) return missing_idx, x_masked - def expected_log_prob(self, target, input, *params, **kwargs): + def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor: missing_idx, target = self._get_masked_obs(target) res = super().expected_log_prob(target, input, *params, **kwargs) return res * ~missing_idx - def log_marginal(self, observations, function_dist, *params, **kwargs): + def log_marginal( + self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any + ) -> Tensor: missing_idx, observations = self._get_masked_obs(observations) res = super().log_marginal(observations, function_dist, *params, **kwargs) return res * ~missing_idx + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal: + """ + :return: Analytic marginal :math:`p(\mathbf y)`. + """ + return super().marginal(function_dist, *args, **kwargs) + class FixedNoiseGaussianLikelihood(_GaussianLikelihoodBase): r""" @@ -186,6 +214,9 @@ class FixedNoiseGaussianLikelihood(_GaussianLikelihoodBase): :var torch.Tensor noise: :math:`\sigma^2` parameter (noise) + .. note:: + FixedNoiseGaussianLikelihood has an analytic marginal distribution. + Example: >>> train_x = torch.randn(55, 2) >>> noises = torch.ones(55) * 0.01 @@ -206,14 +237,13 @@ def __init__( ) -> None: super().__init__(noise_covar=FixedGaussianNoise(noise=noise)) + self.second_noise_covar: Optional[HomoskedasticNoise] = None if learn_additional_noise: noise_prior = kwargs.get("noise_prior", None) noise_constraint = kwargs.get("noise_constraint", None) self.second_noise_covar = HomoskedasticNoise( noise_prior=noise_prior, noise_constraint=noise_constraint, batch_shape=batch_shape ) - else: - self.second_noise_covar = None @property def noise(self) -> Tensor: @@ -224,9 +254,9 @@ def noise(self, value: Tensor) -> None: self.noise_covar.initialize(noise=value) @property - def second_noise(self) -> Tensor: + def second_noise(self) -> Union[float, Tensor]: if self.second_noise_covar is None: - return 0 + return 0.0 else: return self.second_noise_covar.noise @@ -239,11 +269,11 @@ def second_noise(self, value: Tensor) -> None: ) self.second_noise_covar.initialize(noise=value) - def get_fantasy_likelihood(self, **kwargs): + def get_fantasy_likelihood(self, **kwargs: Any) -> "FixedNoiseGaussianLikelihood": if "noise" not in kwargs: raise RuntimeError("FixedNoiseGaussianLikelihood.fantasize requires a `noise` kwarg") old_noise_covar = self.noise_covar - self.noise_covar = None + self.noise_covar = None # pyre-fixme[8] fantasy_liklihood = deepcopy(self) self.noise_covar = old_noise_covar @@ -254,7 +284,7 @@ def get_fantasy_likelihood(self, **kwargs): fantasy_liklihood.noise_covar = FixedGaussianNoise(noise=torch.cat([old_noise, new_noise], -1)) return fantasy_liklihood - def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any): + def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any) -> Union[Tensor, LinearOperator]: if len(params) > 0: # we can infer the shape from the params shape = None @@ -275,6 +305,12 @@ def _shaped_noise_covar(self, base_shape: 
torch.Size, *params: Any, **kwargs: An return res + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal: + """ + :return: Analytic marginal :math:`p(\mathbf y)`. + """ + return super().marginal(function_dist, *args, **kwargs) + class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood): """ @@ -284,18 +320,18 @@ class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood): .. note:: This likelihood can be used for exact or approximate inference. - :param targets: classification labels. - :type targets: torch.Tensor (N). - :param alpha_epsilon: tuning parameter for the scaling of the likeihood targets. We'd suggest 0.01 or setting + :param targets: (... x N) Classification labels. + :param alpha_epsilon: Tuning parameter for the scaling of the likeihood targets. We'd suggest 0.01 or setting via cross-validation. - :type alpha_epsilon: int. - :param learn_additional_noise: Set to true if you additionally want to learn added diagonal noise, similar to GaussianLikelihood. - :type learn_additional_noise: bool, optional :param batch_shape: The batch shape of the learned noise parameter (default []) if :obj:`learn_additional_noise=True`. - :type batch_shape: torch.Size, optional + + :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise) + + .. note:: + DirichletClassificationLikelihood has an analytic marginal distribution. Example: >>> train_x = torch.randn(55, 1) @@ -308,7 +344,9 @@ class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood): >>> pred_y = likelihood(gp_model(test_x), targets=labels) """ - def _prepare_targets(self, targets, alpha_epsilon=0.01, dtype=torch.float): + def _prepare_targets( + self, targets: Tensor, alpha_epsilon: float = 0.01, dtype: torch.dtype = torch.float + ) -> Tuple[Tensor, Tensor, int]: num_classes = int(targets.max() + 1) # set alpha = \alpha_\epsilon alpha = alpha_epsilon * torch.ones(targets.shape[-1], num_classes, device=targets.device, dtype=dtype) @@ -317,7 +355,7 @@ def _prepare_targets(self, targets, alpha_epsilon=0.01, dtype=torch.float): alpha[torch.arange(len(targets)), targets] = alpha[torch.arange(len(targets)), targets] + 1.0 # sigma^2 = log(1 / alpha + 1) - sigma2_i = torch.log(1 / alpha + 1.0) + sigma2_i = torch.log(alpha.reciprocal() + 1.0) # y = log(alpha) - 0.5 * sigma^2 transformed_targets = alpha.log() - 0.5 * sigma2_i @@ -327,12 +365,12 @@ def _prepare_targets(self, targets, alpha_epsilon=0.01, dtype=torch.float): def __init__( self, targets: Tensor, - alpha_epsilon: int = 0.01, + alpha_epsilon: float = 0.01, learn_additional_noise: Optional[bool] = False, - batch_shape: Optional[torch.Size] = torch.Size(), - dtype: Optional[torch.dtype] = torch.float, - **kwargs, - ): + batch_shape: torch.Size = torch.Size(), + dtype: torch.dtype = torch.float, + **kwargs: Any, + ) -> None: sigma2_labels, transformed_targets, num_classes = self._prepare_targets( targets, alpha_epsilon=alpha_epsilon, dtype=dtype ) @@ -342,19 +380,19 @@ def __init__( batch_shape=torch.Size((num_classes,)), **kwargs, ) - self.transformed_targets = transformed_targets.transpose(-2, -1) - self.num_classes = num_classes - self.targets = targets - self.alpha_epsilon = alpha_epsilon + self.transformed_targets: Tensor = transformed_targets.transpose(-2, -1) + self.num_classes: int = num_classes + self.targets: Tensor = targets + self.alpha_epsilon: float = alpha_epsilon - def get_fantasy_likelihood(self, **kwargs): + def get_fantasy_likelihood(self, **kwargs: Any) -> 
"DirichletClassificationLikelihood": # we assume that the number of classes does not change. if "targets" not in kwargs: raise RuntimeError("FixedNoiseGaussianLikelihood.fantasize requires a `targets` kwarg") old_noise_covar = self.noise_covar - self.noise_covar = None + self.noise_covar = None # pyre-fixme[8] fantasy_liklihood = deepcopy(self) self.noise_covar = old_noise_covar @@ -369,10 +407,16 @@ def get_fantasy_likelihood(self, **kwargs): fantasy_liklihood.noise_covar = FixedGaussianNoise(noise=torch.cat([old_noise, new_noise], -1)) return fantasy_liklihood - def __call__(self, *args, **kwargs): + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal: + """ + :return: Analytic marginal :math:`p(\mathbf y)`. + """ + return super().marginal(function_dist, *args, **kwargs) + + def __call__(self, input: Union[Tensor, MultivariateNormal], *args: Any, **kwargs: Any) -> Distribution: if "targets" in kwargs: targets = kwargs.pop("targets") dtype = self.transformed_targets.dtype new_noise, _, _ = self._prepare_targets(targets, dtype=dtype) kwargs["noise"] = new_noise - return super().__call__(*args, **kwargs) + return super().__call__(input, *args, **kwargs) diff --git a/gpytorch/likelihoods/laplace_likelihood.py b/gpytorch/likelihoods/laplace_likelihood.py index 8b151d0c7..c40f8b7cc 100644 --- a/gpytorch/likelihoods/laplace_likelihood.py +++ b/gpytorch/likelihoods/laplace_likelihood.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 +from typing import Any, Optional + import torch +from torch import Tensor +from torch.distributions import Laplace -from ..constraints import Positive +from ..constraints import Interval, Positive from ..distributions import base_distributions +from ..priors import Prior from .likelihood import _OneDimensionalLikelihood @@ -13,16 +18,18 @@ class LaplaceLikelihood(_OneDimensionalLikelihood): It has one learnable parameter: :math:`\sigma` - the noise :param batch_shape: The batch shape of the learned noise parameter (default: []). - :type batch_shape: torch.Size, optional :param noise_prior: Prior for noise parameter :math:`\sigma`. - :type noise_prior: ~gpytorch.priors.Prior, optional :param noise_constraint: Constraint for noise parameter :math:`\sigma`. 
- :type noise_constraint: ~gpytorch.constraints.Interval, optional :var torch.Tensor noise: :math:`\sigma` parameter (noise) """ - def __init__(self, batch_shape=torch.Size([]), noise_prior=None, noise_constraint=None): + def __init__( + self, + batch_shape: torch.Size = torch.Size([]), + noise_prior: Optional[Prior] = None, + noise_constraint: Optional[Interval] = None, + ) -> None: super().__init__() if noise_constraint is None: @@ -36,17 +43,17 @@ def __init__(self, batch_shape=torch.Size([]), noise_prior=None, noise_constrain self.register_constraint("raw_noise", noise_constraint) @property - def noise(self): + def noise(self) -> Tensor: return self.raw_noise_constraint.transform(self.raw_noise) @noise.setter - def noise(self, value): + def noise(self, value: Tensor) -> None: self._set_noise(value) - def _set_noise(self, value): + def _set_noise(self, value: Tensor) -> None: if not torch.is_tensor(value): value = torch.as_tensor(value).to(self.raw_noise) self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value)) - def forward(self, function_samples, **kwargs): + def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Laplace: return base_distributions.Laplace(loc=function_samples, scale=self.noise.sqrt()) diff --git a/gpytorch/likelihoods/likelihood.py b/gpytorch/likelihoods/likelihood.py index 81337ed86..0a5b69479 100644 --- a/gpytorch/likelihoods/likelihood.py +++ b/gpytorch/likelihoods/likelihood.py @@ -4,8 +4,11 @@ import warnings from abc import ABC, abstractmethod from copy import deepcopy +from typing import Any, Dict, Optional, Union import torch +from torch import Tensor +from torch.distributions import Distribution as _Distribution from .. import settings from ..distributions import base_distributions, MultivariateNormal @@ -15,11 +18,15 @@ class _Likelihood(Module, ABC): - def __init__(self, max_plate_nesting=1): + has_analytic_marginal: bool = False + + def __init__(self, max_plate_nesting: int = 1) -> None: super().__init__() - self.max_plate_nesting = max_plate_nesting + self.max_plate_nesting: int = max_plate_nesting - def _draw_likelihood_samples(self, function_dist, *args, sample_shape=None, **kwargs): + def _draw_likelihood_samples( + self, function_dist: MultivariateNormal, *args: Any, sample_shape: Optional[torch.Size] = None, **kwargs: Any + ) -> _Distribution: if sample_shape is None: sample_shape = torch.Size( [settings.num_likelihood_samples.value()] @@ -34,32 +41,36 @@ def _draw_likelihood_samples(self, function_dist, *args, sample_shape=None, **kw function_samples = function_dist.rsample(sample_shape) return self.forward(function_samples, *args, **kwargs) - def expected_log_prob(self, observations, function_dist, *args, **kwargs): + def expected_log_prob( + self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any + ) -> Tensor: likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs) res = likelihood_samples.log_prob(observations).mean(dim=0) return res @abstractmethod - def forward(self, function_samples, *args, **kwargs): + def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> _Distribution: raise NotImplementedError - def get_fantasy_likelihood(self, **kwargs): + def get_fantasy_likelihood(self, **kwargs: Any) -> "_Likelihood": return deepcopy(self) - def log_marginal(self, observations, function_dist, *args, **kwargs): + def log_marginal( + self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any + ) -> Tensor: 
likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs) log_probs = likelihood_samples.log_prob(observations) res = log_probs.sub(math.log(log_probs.size(0))).logsumexp(dim=0) return res - def marginal(self, function_dist, *args, **kwargs): + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> _Distribution: res = self._draw_likelihood_samples(function_dist, *args, **kwargs) return res - def __call__(self, input, *args, **kwargs): + def __call__(self, input: Union[Tensor, MultivariateNormal], *args: Any, **kwargs: Any) -> _Distribution: # Conditional if torch.is_tensor(input): - return super().__call__(input, *args, **kwargs) + return super().__call__(input, *args, **kwargs) # pyre-ignore[7] # Marginal elif isinstance(input, MultivariateNormal): return self.marginal(input, *args, **kwargs) @@ -99,22 +110,17 @@ class Likelihood(_Likelihood): requires a forward method that computes the conditional distribution :math:`p(y \mid f(\mathbf x))`. - Calling this object does one of two things: - - - If likelihood is called with a :class:`torch.Tensor` object, then it is - assumed that the input is samples from :math:`f(\mathbf x)`. This - returns the *conditional* distribution :math:`p(y|f(\mathbf x))`. - - If likelihood is called with a :class:`~gpytorch.distribution.MultivariateNormal` object, - then it is assumed that the input is the distribution :math:`f(\mathbf x)`. - This returns the *marginal* distribution :math:`p(y|\mathbf x)`. - - :param max_plate_nesting: (For Pyro integration only). How many batch dimensions are in the function. - This should be modified if the likelihood uses plated random variables. - :type max_plate_nesting: int, default=1 + :param bool has_analytic_marginal: Whether or not the marginal distribution :math:`p(\mathbf y)` + can be computed in closed form. (See :meth:`~gpytorch.likelihoods.Likelihood.__call__` docstring.) + :param max_plate_nesting: (For Pyro integration only.) How many batch dimensions are in the function. + This should be modified if the likelihood uses plated random variables. (Default = 1) + :param str name_prefix: (For Pyro integration only.) Prefix to assign to named Pyro latent variables. + :param int num_data: (For Pyro integration only.) Total number of observations. 
""" @property - def num_data(self): + def num_data(self) -> int: if hasattr(self, "_num_data"): return self._num_data else: @@ -124,21 +130,23 @@ def num_data(self): return "" @num_data.setter - def num_data(self, val): + def num_data(self, val: int) -> None: self._num_data = val @property - def name_prefix(self): + def name_prefix(self) -> str: if hasattr(self, "_name_prefix"): return self._name_prefix else: return "" @name_prefix.setter - def name_prefix(self, val): + def name_prefix(self, val: str) -> None: self._name_prefix = val - def _draw_likelihood_samples(self, function_dist, *args, sample_shape=None, **kwargs): + def _draw_likelihood_samples( + self, function_dist: _Distribution, *args: Any, sample_shape: Optional[torch.Size] = None, **kwargs: Any + ) -> _Distribution: if self.training: num_event_dims = len(function_dist.event_shape) function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt()) @@ -160,7 +168,9 @@ def _draw_likelihood_samples(self, function_dist, *args, sample_shape=None, **kw function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1) return self.forward(function_samples, *args, **kwargs) - def expected_log_prob(self, observations, function_dist, *args, **kwargs): + def expected_log_prob( + self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any + ) -> Tensor: r""" (Used by :obj:`~gpytorch.mlls.VariationalELBO` for variational inference.) @@ -170,36 +180,37 @@ def expected_log_prob(self, observations, function_dist, *args, **kwargs): \sum_{\mathbf x, y} \mathbb{E}_{q\left( f(\mathbf x) \right)} \left[ \log p \left( y \mid f(\mathbf x) \right) \right] - :param torch.Tensor observations: Values of :math:`y`. - :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution for :math:`f(x)`. + :param observations: Values of :math:`y`. + :param function_dist: Distribution for :math:`f(x)`. :param args: Additional args (passed to the foward function). :param kwargs: Additional kwargs (passed to the foward function). - :rtype: torch.Tensor """ return super().expected_log_prob(observations, function_dist, *args, **kwargs) @abstractmethod - def forward(self, function_samples, *args, data={}, **kwargs): + def forward( + self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any + ) -> _Distribution: r""" Computes the conditional distribution :math:`p(\mathbf y \mid \mathbf f, \ldots)` that defines the likelihood. - :param torch.Tensor function_samples: Samples from the function (:math:`\mathbf f`) - :param data: Additional variables that the likelihood needs to condition + :param function_samples: Samples from the function (:math:`\mathbf f`) + :param data: (Pyro integration only.) Additional variables that the likelihood needs to condition on. The keys of the dictionary will correspond to Pyro sample sites in the likelihood's model/guide. 
- :type data: dict {str: torch.Tensor}, optional - Pyro integration only :param args: Additional args :param kwargs: Additional kwargs - :rtype: :obj:`Distribution` (with same shape as function_samples ) """ raise NotImplementedError - def get_fantasy_likelihood(self, **kwargs): + def get_fantasy_likelihood(self, **kwargs: Any) -> "_Likelihood": """""" return super().get_fantasy_likelihood(**kwargs) - def log_marginal(self, observations, function_dist, *args, **kwargs): + def log_marginal( + self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any + ) -> Tensor: r""" (Used by :obj:`~gpytorch.mlls.PredictiveLogLikelihood` for approximate inference.) @@ -212,15 +223,14 @@ def log_marginal(self, observations, function_dist, *args, **kwargs): Note that this differs from :meth:`expected_log_prob` because the :math:`log` is on the outside of the expectation. - :param torch.Tensor observations: Values of :math:`y`. - :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution for :math:`f(x)`. + :param observations: Values of :math:`y`. + :param function_dist: Distribution for :math:`f(x)`. :param args: Additional args (passed to the forward function). :param kwargs: Additional kwargs (passed to the forward function). - :rtype: torch.Tensor """ return super().log_marginal(observations, function_dist, *args, **kwargs) - def marginal(self, function_dist, *args, **kwargs): + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> _Distribution: r""" Computes a predictive distribution :math:`p(y^* | \mathbf x^*)` given either a posterior distribution :math:`p(\mathbf f | \mathcal D, \mathbf x)` or a @@ -232,31 +242,30 @@ prior distribution :math:`p(\mathbf f|\mathbf x)` as input. With both exact inference and variational inference, the form of :math:`p(\mathbf f|\mathcal D, \mathbf x)` or :math:`p(\mathbf f| \mathbf x)` should usually be a :obj:`~gpytorch.distributions.MultivariateNormal` specified by the mean and (co)variance of :math:`p(\mathbf f|...)`. - :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution for :math:`f(x)`. + :param function_dist: Distribution for :math:`f(x)`. :param args: Additional args (passed to the forward function). :param kwargs: Additional kwargs (passed to the forward function). :return: The marginal distribution, or samples from it. - :rtype: ~gpytorch.distributions.Distribution """ return super().marginal(function_dist, *args, **kwargs) - def pyro_guide(self, function_dist, target, *args, **kwargs): + def pyro_guide(self, function_dist: MultivariateNormal, target: Tensor, *args: Any, **kwargs: Any) -> None: r""" (For Pyro integration only). Part of the guide function for the likelihood. This should be re-defined if the likelihood contains any latent variables that need to be inferred. - :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution of latent function + :param function_dist: Distribution of latent function :math:`q(\mathbf f)`. - :param torch.Tensor target: Observed :math:`\mathbf y`. - :param args: Additional args (for :meth:`~forward`). - :param kwargs: Additional kwargs (for :meth:`~forward`). + :param target: Observed :math:`\mathbf y`. + :param args: Additional args (passed to the forward function). + :param kwargs: Additional kwargs (passed to the forward function). """ with pyro.plate(self.name_prefix + ".data_plate", dim=-1): pyro.sample(self.name_prefix + ".f", function_dist) - def pyro_model(self, function_dist, target, *args, **kwargs): + def pyro_model(self, function_dist: MultivariateNormal, target: Tensor, *args: Any, **kwargs: Any) -> Tensor: r""" (For Pyro integration only). 
@@ -264,23 +273,83 @@ It should return the sampled target values. This should be re-defined if the likelihood contains any latent variables that need to be inferred. - :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution of latent function + :param function_dist: Distribution of latent function :math:`p(\mathbf f)`. - :param torch.Tensor target: Observed :math:`\mathbf y`. - :param args: Additional args (for :meth:`~forward`). - :param kwargs: Additional kwargs (for :meth:`~forward`). + :param target: Observed :math:`\mathbf y`. + :param args: Additional args (passed to the forward function). + :param kwargs: Additional kwargs (passed to the forward function). """ with pyro.plate(self.name_prefix + ".data_plate", dim=-1): function_samples = pyro.sample(self.name_prefix + ".f", function_dist) output_dist = self(function_samples, *args, **kwargs) return self.sample_target(output_dist, target) - def sample_target(self, output_dist, target): + def sample_target(self, output_dist: MultivariateNormal, target: Tensor) -> Tensor: scale = (self.num_data or output_dist.batch_shape[-1]) / output_dist.batch_shape[-1] - with pyro.poutine.scale(scale=scale): + with pyro.poutine.scale(scale=scale): # pyre-ignore[16] return pyro.sample(self.name_prefix + ".y", output_dist, obs=target) - def __call__(self, input, *args, **kwargs): + def __call__(self, input: Union[Tensor, MultivariateNormal], *args: Any, **kwargs: Any) -> _Distribution: + r""" + Calling this object does one of two things: + + 1. If likelihood is called with a :class:`torch.Tensor` object, then it is + assumed that the input is samples from :math:`f(\mathbf x)`. This + returns the *conditional* distribution :math:`p(y|f(\mathbf x))`. + + .. code-block:: python + + f = torch.randn(20) + likelihood = gpytorch.likelihoods.GaussianLikelihood() + conditional = likelihood(f) + print(type(conditional), conditional.batch_shape, conditional.event_shape) + # >>> torch.distributions.Normal, torch.Size([20]), torch.Size([]) + + 2. If likelihood is called with a :class:`~gpytorch.distributions.MultivariateNormal` object, + then it is assumed that the input is the distribution :math:`f(\mathbf x)`. + This returns the *marginal* distribution :math:`p(y|\mathbf x)`. + + The form of the marginal distribution depends on the likelihood. + For :class:`~gpytorch.likelihoods.BernoulliLikelihood` and + :class:`~gpytorch.likelihoods.GaussianLikelihood` objects, the marginal distribution + can be computed analytically, and the likelihood returns the analytic distribution. + For most other likelihoods, there is no analytic form for the marginal, + and so the likelihood instead returns a batch of Monte Carlo samples from the marginal. + + .. 
+           .. code-block:: python
+
+               mean = torch.randn(20)
+               covar = linear_operator.operators.DiagLinearOperator(torch.ones(20))
+               f = gpytorch.distributions.MultivariateNormal(mean, covar)
+
+               # Analytic marginal computation - Bernoulli and Gaussian likelihoods only
+               analytic_marginal_likelihood = gpytorch.likelihoods.GaussianLikelihood()
+               marginal = analytic_marginal_likelihood(f)
+               print(type(marginal), marginal.batch_shape, marginal.event_shape)
+               # >>> gpytorch.distributions.MultivariateNormal, torch.Size([]), torch.Size([20])
+
+               # MC marginal computation - all other likelihoods
+               mc_marginal_likelihood = gpytorch.likelihoods.BetaLikelihood()
+               with gpytorch.settings.num_likelihood_samples(15):
+                   marginal = mc_marginal_likelihood(f)
+               print(type(marginal), marginal.batch_shape, marginal.event_shape)
+               # >>> torch.distributions.Beta, torch.Size([15, 20]), torch.Size([])
+               # (The batch_shape of torch.Size([15, 20]) represents 15 MC samples for 20 data points.)
+
+        .. note::
+
+            If a Likelihood supports analytic marginals, the :attr:`has_analytic_marginal` property will be True.
+            If a Likelihood does not support analytic marginals, you can set the number of Monte Carlo
+            samples using the :class:`gpytorch.settings.num_likelihood_samples` context manager.
+
+        :param input: Either a (... x N) sample from :math:`\mathbf f`
+            or a (... x N) MVN distribution of :math:`\mathbf f`.
+        :param args: Additional args (passed to the forward function).
+        :param kwargs: Additional kwargs (passed to the forward function).
+        :return: Either a conditional :math:`p(\mathbf y \mid \mathbf f)`
+            or marginal :math:`p(\mathbf y)`
+            based on whether :attr:`input` is a Tensor or a MultivariateNormal (see above).
+        """
         # Conditional
         if torch.is_tensor(input):
             return super().__call__(input, *args, **kwargs)
@@ -288,14 +357,14 @@ def __call__(self, input, *args, **kwargs):
         elif any(
             [
                 isinstance(input, MultivariateNormal),
-                isinstance(input, pyro.distributions.Normal),
+                isinstance(input, pyro.distributions.Normal),  # pyre-ignore[16]
                 (
-                    isinstance(input, pyro.distributions.Independent)
-                    and isinstance(input.base_dist, pyro.distributions.Normal)
+                    isinstance(input, pyro.distributions.Independent)  # pyre-ignore[16]
+                    and isinstance(input.base_dist, pyro.distributions.Normal)  # pyre-ignore[16]
                 ),
             ]
         ):
-            return self.marginal(input, *args, **kwargs)
+            return self.marginal(input, *args, **kwargs)  # pyre-ignore[6]
         # Error
         else:
             raise RuntimeError(
@@ -307,21 +376,21 @@ def __call__(self, input, *args, **kwargs):

 class Likelihood(_Likelihood):
     @property
-    def num_data(self):
+    def num_data(self) -> int:
         warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
         return 0

     @num_data.setter
-    def num_data(self, val):
+    def num_data(self, val: int) -> None:
         warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)

     @property
-    def name_prefix(self):
+    def name_prefix(self) -> str:
         warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
         return ""

     @name_prefix.setter
-    def name_prefix(self, val):
+    def name_prefix(self, val: str) -> None:
         warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)

@@ -334,16 +403,20 @@ class _OneDimensionalLikelihood(Likelihood, ABC):
     by using 1D Gauss-Hermite quadrature.
""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.quadrature = GaussHermiteQuadrature1D() - def expected_log_prob(self, observations, function_dist, *args, **kwargs): + def expected_log_prob( + self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any + ) -> Tensor: log_prob_lambda = lambda function_samples: self.forward(function_samples).log_prob(observations) log_prob = self.quadrature(log_prob_lambda, function_dist) return log_prob - def log_marginal(self, observations, function_dist, *args, **kwargs): + def log_marginal( + self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any + ) -> Tensor: prob_lambda = lambda function_samples: self.forward(function_samples).log_prob(observations).exp() prob = self.quadrature(prob_lambda, function_dist) return prob.log() diff --git a/gpytorch/likelihoods/multitask_gaussian_likelihood.py b/gpytorch/likelihoods/multitask_gaussian_likelihood.py index 50c13c50a..6864ab801 100644 --- a/gpytorch/likelihoods/multitask_gaussian_likelihood.py +++ b/gpytorch/likelihoods/multitask_gaussian_likelihood.py @@ -13,13 +13,14 @@ RootLinearOperator, ) from torch import Tensor +from torch.distributions import Normal -from ..constraints import GreaterThan +from ..constraints import GreaterThan, Interval from ..distributions import base_distributions, MultitaskMultivariateNormal from ..lazy import LazyEvaluatedKernelTensor from ..likelihoods import _GaussianLikelihoodBase, Likelihood -from ..module import Module from ..priors import Prior +from .noise_models import FixedGaussianNoise, Noise class _MultitaskGaussianLikelihoodBase(_GaussianLikelihoodBase): @@ -39,17 +40,17 @@ class _MultitaskGaussianLikelihoodBase(_GaussianLikelihoodBase): def __init__( self, num_tasks: int, - noise_covar: Module, - rank: Optional[int] = 0, + noise_covar: Union[Noise, FixedGaussianNoise], + rank: int = 0, task_correlation_prior: Optional[Prior] = None, batch_shape: torch.Size = torch.Size(), - ): + ) -> None: super().__init__(noise_covar=noise_covar) if rank != 0: if rank > num_tasks: raise ValueError(f"Cannot have rank ({rank}) greater than num_tasks ({num_tasks})") tidcs = torch.tril_indices(num_tasks, rank, dtype=torch.long) - self.tidcs = tidcs[:, 1:] # (1, 1) must be 1.0, no need to parameterize this + self.tidcs: Tensor = tidcs[:, 1:] # (1, 1) must be 1.0, no need to parameterize this task_noise_corr = torch.randn(*batch_shape, self.tidcs.size(-1)) self.register_parameter("task_noise_corr", torch.nn.Parameter(task_noise_corr)) if task_correlation_prior is not None: @@ -70,7 +71,9 @@ def _eval_corr_matrix(self) -> Tensor: C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt() return C @ C.transpose(-1, -2) - def marginal(self, function_dist: MultitaskMultivariateNormal, *params, **kwargs) -> MultitaskMultivariateNormal: + def marginal( + self, function_dist: MultitaskMultivariateNormal, *params: Any, **kwargs: Any + ) -> MultitaskMultivariateNormal: # pyre-ignore[14] r""" If :math:`\text{rank} = 0`, adds the task noises to the diagonal of the covariance matrix of the supplied @@ -111,7 +114,7 @@ def marginal(self, function_dist: MultitaskMultivariateNormal, *params, **kwargs return function_dist.__class__(mean, covar, interleaved=function_dist._interleaved) def _shaped_noise_covar( - self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved: bool = True, *params, **kwargs + self, shape: torch.Size, add_noise: 
+        self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved: bool = True, *params: Any, **kwargs: Any
     ) -> LinearOperator:
         if not self.has_task_noise:
             noise = ConstantDiagLinearOperator(self.noise, diag_shape=shape[-2] * self.num_tasks)
@@ -131,7 +134,7 @@ def _shaped_noise_covar(
             eye_lt = ConstantDiagLinearOperator(
                 torch.ones(*shape[:-2], 1, dtype=dtype, device=device), diag_shape=shape[-2]
             )
-            task_var_lt = task_var_lt.expand(*shape[:-2], *task_var_lt.matrix_shape)
+            task_var_lt = task_var_lt.expand(*shape[:-2], *task_var_lt.matrix_shape)  # pyre-ignore[6]

             # to add the latent noise we exploit the fact that
             # I \kron D_T + \sigma^2 I_{NT} = I \kron (D_T + \sigma^2 I)
@@ -148,7 +151,7 @@ def _shaped_noise_covar(

         return covar_kron_lt

-    def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
+    def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal:
         noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
         noise = noise.reshape(*noise.shape[:-1], *function_samples.shape[-2:])
         return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)
@@ -166,6 +169,9 @@ class MultitaskGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
     .. note::
         At least one of :attr:`has_global_noise` or :attr:`has_task_noise` should be specified.

+    .. note::
+        MultitaskGaussianLikelihood has an analytic marginal distribution.
+
     :param num_tasks: Number of tasks.
     :param noise_covar: A model for the noise covariance. This can be a simple homoskedastic
         noise model, or a GP that is to be fitted on the observed measurement errors.
@@ -177,20 +183,24 @@ class MultitaskGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
     :param has_global_noise: Whether to include a :math:`\sigma^2 \mathbf I_{nt}` term in the noise model.
     :param has_task_noise: Whether to include task-specific noise terms, which add
         :math:`\mathbf I_n \otimes \mathbf D_T` into the noise model.
+
+    :ivar torch.Tensor task_noise_covar: The inter-task noise covariance matrix
+    :ivar torch.Tensor task_noises: (Optional) task-specific noise variances (added onto the `task_noise_covar`)
+    :ivar torch.Tensor noise: (Optional) global noise variance (added onto the `task_noise_covar`)
     """

     def __init__(
         self,
-        num_tasks,
-        rank=0,
-        task_prior=None,
-        batch_shape=torch.Size(),
-        noise_prior=None,
-        noise_constraint=None,
-        has_global_noise=True,
-        has_task_noise=True,
-    ):
-        super(Likelihood, self).__init__()
+        num_tasks: int,
+        rank: int = 0,
+        batch_shape: torch.Size = torch.Size(),
+        task_prior: Optional[Prior] = None,
+        noise_prior: Optional[Prior] = None,
+        noise_constraint: Optional[Interval] = None,
+        has_global_noise: bool = True,
+        has_task_noise: bool = True,
+    ) -> None:
+        super(Likelihood, self).__init__()  # pyre-ignore[20]
         if noise_constraint is None:
             noise_constraint = GreaterThan(1e-4)
@@ -230,31 +240,31 @@ def __init__(
         self.has_task_noise = has_task_noise

     @property
-    def noise(self) -> Tensor:
+    def noise(self) -> Optional[Tensor]:
         return self.raw_noise_constraint.transform(self.raw_noise)

     @noise.setter
-    def noise(self, value: Union[float, Tensor]):
+    def noise(self, value: Union[float, Tensor]) -> None:
         self._set_noise(value)

     @property
-    def task_noises(self) -> Tensor:
+    def task_noises(self) -> Optional[Tensor]:
         if self.rank == 0:
             return self.raw_task_noises_constraint.transform(self.raw_task_noises)
         else:
             raise AttributeError("Cannot set diagonal task noises when covariance has ", self.rank, ">0")

     @task_noises.setter
-    def task_noises(self, value: Union[float, Tensor]):
+    def task_noises(self, value: Union[float, Tensor]) -> None:
         if self.rank == 0:
             self._set_task_noises(value)
         else:
             raise AttributeError("Cannot set diagonal task noises when covariance has ", self.rank, ">0")

-    def _set_noise(self, value: Union[float, Tensor]):
+    def _set_noise(self, value: Union[float, Tensor]) -> None:
         self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))

-    def _set_task_noises(self, value: Union[float, Tensor]):
+    def _set_task_noises(self, value: Union[float, Tensor]) -> None:
         self.initialize(raw_task_noises=self.raw_task_noises_constraint.inverse_transform(value))

     @property
@@ -265,7 +275,7 @@ def task_noise_covar(self) -> Tensor:
         raise AttributeError("Cannot retrieve task noises when covariance is diagonal.")

     @task_noise_covar.setter
-    def task_noise_covar(self, value: Union[float, Tensor]):
+    def task_noise_covar(self, value: Tensor) -> None:
         # internally uses a pivoted cholesky decomposition to construct a low rank
         # approximation of the covariance
         if self.rank > 0:
@@ -277,5 +287,13 @@ def task_noise_covar(self, value: Union[float, Tensor]):
     def _eval_covar_matrix(self) -> Tensor:
         covar_factor = self.task_noise_covar_factor
         noise = self.noise
-        D = noise * torch.eye(self.num_tasks, dtype=noise.dtype, device=noise.device)
+        D = noise * torch.eye(self.num_tasks, dtype=noise.dtype, device=noise.device)  # pyre-fixme[16]
         return covar_factor.matmul(covar_factor.transpose(-1, -2)) + D
+
+    def marginal(
+        self, function_dist: MultitaskMultivariateNormal, *args: Any, **kwargs: Any
+    ) -> MultitaskMultivariateNormal:
+        r"""
+        :return: Analytic marginal :math:`p(\mathbf y)`.
+ """ + return super().marginal(function_dist, *args, **kwargs) diff --git a/gpytorch/likelihoods/noise_models.py b/gpytorch/likelihoods/noise_models.py index 1f9afc782..e74de1f06 100644 --- a/gpytorch/likelihoods/noise_models.py +++ b/gpytorch/likelihoods/noise_models.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 import warnings -from typing import Any, Optional +from typing import Any, Optional, Union import torch -from linear_operator.operators import ConstantDiagLinearOperator, DiagLinearOperator, ZeroLinearOperator +from linear_operator.operators import ConstantDiagLinearOperator, DiagLinearOperator, LinearOperator, ZeroLinearOperator from torch import Tensor from torch.nn import Parameter @@ -16,7 +16,11 @@ class Noise(Module): - pass + def __call__( + self, *params: Any, shape: Optional[torch.Size] = None, **kwargs: Any + ) -> Union[Tensor, LinearOperator]: + # For corredct typing + return super().__call__(*params, shape=shape, **kwargs) class _HomoskedasticNoiseBase(Noise): @@ -167,3 +171,9 @@ def forward( def _apply(self, fn): self.noise = fn(self.noise) return super(FixedGaussianNoise, self)._apply(fn) + + def __call__( + self, *params: Any, shape: Optional[torch.Size] = None, **kwargs: Any + ) -> Union[Tensor, LinearOperator]: + # For corredct typing + return super().__call__(*params, shape=shape, **kwargs) diff --git a/gpytorch/likelihoods/softmax_likelihood.py b/gpytorch/likelihoods/softmax_likelihood.py index 882382a6d..fa16db253 100644 --- a/gpytorch/likelihoods/softmax_likelihood.py +++ b/gpytorch/likelihoods/softmax_likelihood.py @@ -1,10 +1,14 @@ #!/usr/bin/env python3 import warnings +from typing import Any, Optional, Union import torch +from torch import Tensor +from torch.distributions import Categorical, Distribution -from ..distributions import base_distributions, Distribution, MultitaskMultivariateNormal +from ..distributions import base_distributions, MultitaskMultivariateNormal +from ..priors import Prior from .likelihood import Likelihood @@ -17,23 +21,30 @@ class SoftmaxLikelihood(Likelihood): :math:`\mathbf W` is a set of linear mixing weights applied to the latent functions :math:`\mathbf f`. - :param int num_features: Dimensionality of latent function :math:`\mathbf f`. - :param int num_classes: Number of classes. - :param bool mixing_weights: (Default: `True`) Whether to learn a linear mixing weight :math:`\mathbf W` applied to + :param num_features: Dimensionality of latent function :math:`\mathbf f`. + :param num_classes: Number of classes. + :param mixing_weights: (Default: `True`) Whether to learn a linear mixing weight :math:`\mathbf W` applied to the latent function :math:`\mathbf f`. If `False`, then :math:`\mathbf W = \mathbf I`. :param mixing_weights_prior: Prior to use over the mixing weights :math:`\mathbf W`. - :type mixing_weights_prior: ~gpytorch.priors.Prior, optional + + :ivar torch.Tensor mixing_weights: (Optional) mixing weights. 
""" - def __init__(self, num_features=None, num_classes=None, mixing_weights=True, mixing_weights_prior=None): + def __init__( + self, + num_features: Optional[int] = None, + num_classes: int = None, # pyre-fixme[9] + mixing_weights: bool = True, + mixing_weights_prior: Optional[Prior] = None, + ) -> None: super().__init__() if num_classes is None: raise ValueError("num_classes is required") self.num_classes = num_classes - if mixing_weights: - self.num_features = num_features + if mixing_weights is not None: if num_features is None: raise ValueError("num_features is required with mixing weights") + self.num_features: int = num_features self.register_parameter( name="mixing_weights", parameter=torch.nn.Parameter(torch.randn(num_classes, num_features).div_(num_features)), @@ -42,9 +53,9 @@ def __init__(self, num_features=None, num_classes=None, mixing_weights=True, mix self.register_prior("mixing_weights_prior", mixing_weights_prior, "mixing_weights") else: self.num_features = num_classes - self.mixing_weights = None + self.mixing_weights: Optional[torch.nn.Parameter] = None - def forward(self, function_samples, *params, **kwargs): + def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Categorical: num_data, num_features = function_samples.shape[-2:] # Catch legacy mode @@ -67,12 +78,12 @@ def forward(self, function_samples, *params, **kwargs): res = base_distributions.Categorical(logits=mixed_fs) return res - def __call__(self, function, *params, **kwargs): - if isinstance(function, Distribution) and not isinstance(function, MultitaskMultivariateNormal): + def __call__(self, input: Union[Tensor, MultitaskMultivariateNormal], *args: Any, **kwargs: Any) -> Distribution: + if isinstance(input, Distribution) and not isinstance(input, MultitaskMultivariateNormal): warnings.warn( "The input to SoftmaxLikelihood should be a MultitaskMultivariateNormal (num_data x num_tasks). " "Batch MultivariateNormal inputs (num_tasks x num_data) will be deprectated.", DeprecationWarning, ) - function = MultitaskMultivariateNormal.from_batch_mvn(function) - return super().__call__(function, *params, **kwargs) + input = MultitaskMultivariateNormal.from_batch_mvn(input) + return super().__call__(input, *args, **kwargs) diff --git a/gpytorch/likelihoods/student_t_likelihood.py b/gpytorch/likelihoods/student_t_likelihood.py index 20ae2b0df..0e47dcdb9 100644 --- a/gpytorch/likelihoods/student_t_likelihood.py +++ b/gpytorch/likelihoods/student_t_likelihood.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 +from typing import Any, Optional + import torch +from torch import Tensor +from torch.distributions import StudentT -from ..constraints import GreaterThan, Positive +from ..constraints import GreaterThan, Interval, Positive from ..distributions import base_distributions +from ..priors import Prior from .likelihood import _OneDimensionalLikelihood @@ -14,15 +19,10 @@ class StudentTLikelihood(_OneDimensionalLikelihood): :math:`\sigma^2` - the noise :param batch_shape: The batch shape of the learned noise parameter (default: []). - :type batch_shape: torch.Size, optional :param noise_prior: Prior for noise parameter :math:`\sigma^2`. - :type noise_prior: ~gpytorch.priors.Prior, optional :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`. - :type noise_constraint: ~gpytorch.constraints.Interval, optional :param deg_free_prior: Prior for deg_free parameter :math:`\nu`. 
-    :type deg_free_prior: ~gpytorch.priors.Prior, optional
     :param deg_free_constraint: Constraint for deg_free parameter :math:`\nu`.
-    :type deg_free_constraint: ~gpytorch.constraints.Interval, optional

     :var torch.Tensor deg_free: :math:`\nu` parameter (degrees of freedom)
     :var torch.Tensor noise: :math:`\sigma^2` parameter (noise)
@@ -31,11 +31,11 @@ class StudentTLikelihood(_OneDimensionalLikelihood):
     def __init__(
         self,
         batch_shape: torch.Size = torch.Size([]),
-        deg_free_prior=None,
-        deg_free_constraint=None,
-        noise_prior=None,
-        noise_constraint=None,
-    ):
+        deg_free_prior: Optional[Prior] = None,
+        deg_free_constraint: Optional[Interval] = None,
+        noise_prior: Optional[Prior] = None,
+        noise_constraint: Optional[Interval] = None,
+    ) -> None:
         super().__init__()

         if deg_free_constraint is None:
@@ -61,30 +61,30 @@ def __init__(
         self.initialize(deg_free=7)

     @property
-    def deg_free(self):
+    def deg_free(self) -> Tensor:
         return self.raw_deg_free_constraint.transform(self.raw_deg_free)

     @deg_free.setter
-    def deg_free(self, value):
+    def deg_free(self, value: Tensor) -> None:
         self._set_deg_free(value)

-    def _set_deg_free(self, value):
+    def _set_deg_free(self, value: Tensor) -> None:
         if not torch.is_tensor(value):
             value = torch.as_tensor(value).to(self.raw_deg_free)
         self.initialize(raw_deg_free=self.raw_deg_free_constraint.inverse_transform(value))

     @property
-    def noise(self):
+    def noise(self) -> Tensor:
         return self.raw_noise_constraint.transform(self.raw_noise)

     @noise.setter
-    def noise(self, value):
+    def noise(self, value: Tensor) -> None:
         self._set_noise(value)

-    def _set_noise(self, value):
+    def _set_noise(self, value: Tensor) -> None:
         if not torch.is_tensor(value):
             value = torch.as_tensor(value).to(self.raw_noise)
         self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))

-    def forward(self, function_samples, **kwargs):
+    def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> StudentT:
         return base_distributions.StudentT(df=self.deg_free, loc=function_samples, scale=self.noise.sqrt())
diff --git a/gpytorch/module.py b/gpytorch/module.py
index d3be8eb33..008c6645a 100644
--- a/gpytorch/module.py
+++ b/gpytorch/module.py
@@ -27,7 +27,7 @@ def __init__(self):

         self._register_load_state_dict_pre_hook(self._load_state_hook_ignore_shapes)

-    def __call__(self, *inputs, **kwargs):
+    def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
         outputs = self.forward(*inputs, **kwargs)
         if isinstance(outputs, list):
             return [_validate_module_outputs(output) for output in outputs]
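To close, a short usage sketch exercising the newly typed StudentTLikelihood accessors and the tensor-input (conditional) path of Likelihood.__call__ (not part of the diff; parameter values are illustrative):

.. code-block:: python

    import torch
    import gpytorch

    likelihood = gpytorch.likelihoods.StudentTLikelihood()

    # The typed setters accept tensors; floats are converted via torch.as_tensor
    likelihood.deg_free = torch.tensor(5.0)
    likelihood.noise = torch.tensor(0.25)

    # Tensor input -> conditional StudentT distribution p(y | f)
    f = torch.randn(20)
    conditional = likelihood(f)
    print(type(conditional), conditional.sample().shape)
    # StudentT with scale = noise.sqrt(); sample shape: torch.Size([20])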