|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import time |
| 4 | +from typing import Dict |
| 5 | + |
| 6 | +from requests import Response |
| 7 | +from rich.console import Console |
| 8 | + |
| 9 | +from .config import JSONDict, ValidationConfig |
| 10 | +from .interface import CashflowInterface, TriangleInterface |
| 11 | +from .requester import Requester |
| 12 | +from .triangle import Triangle |
| 13 | + |
| 14 | + |
| 15 | +class CashflowModel(CashflowInterface): |
| 16 | + def __init__( |
| 17 | + self, |
| 18 | + id: str, |
| 19 | + name: str, |
| 20 | + dev_model_name: str, |
| 21 | + tail_model_name: str, |
| 22 | + model_class: str, |
| 23 | + endpoint: str, |
| 24 | + requester: Requester, |
| 25 | + asynchronous: bool = False, |
| 26 | + ) -> None: |
| 27 | + super().__init__(model_class, endpoint, requester, asynchronous) |
| 28 | + |
| 29 | + self._endpoint = endpoint |
| 30 | + self._id = id |
| 31 | + self._name = name |
| 32 | + self._dev_model_name = dev_model_name |
| 33 | + self._tail_model_name = tail_model_name |
| 34 | + self._model_class = model_class |
| 35 | + self._fit_response: Response | None = None |
| 36 | + self._predict_response: Response | None = None |
| 37 | + self._get_response: Response | None = None |
| 38 | + |
| 39 | + id = property(lambda self: self._id) |
| 40 | + name = property(lambda self: self._name) |
| 41 | + dev_model_name = property(lambda self: self._dev_model_name) |
| 42 | + tail_model_name = property(lambda self: self._tail_model_name) |
| 43 | + model_class = property(lambda self: self._model_class) |
| 44 | + endpoint = property(lambda self: self._endpoint) |
| 45 | + fit_response = property(lambda self: self._fit_response) |
| 46 | + predict_response = property(lambda self: self._predict_response) |
| 47 | + get_response = property(lambda self: self._get_response) |
| 48 | + delete_response = property(lambda self: self._delete_response) |
| 49 | + |
| 50 | + @classmethod |
| 51 | + def get( |
| 52 | + cls, |
| 53 | + id: str, |
| 54 | + name: str, |
| 55 | + dev_model_name: str, |
| 56 | + tail_model_name: str, |
| 57 | + model_class: str, |
| 58 | + endpoint: str, |
| 59 | + requester: Requester, |
| 60 | + asynchronous: bool = False, |
| 61 | + ) -> CashflowModel: |
| 62 | + console = Console() |
| 63 | + with console.status("Retrieving...", spinner="bouncingBar") as _: |
| 64 | + console.log(f"Getting model '{name}' with ID '{id}'") |
| 65 | + get_response = requester.get(endpoint, stream=True) |
| 66 | + |
| 67 | + self = cls( |
| 68 | + id, |
| 69 | + name, |
| 70 | + dev_model_name, |
| 71 | + tail_model_name, |
| 72 | + model_class, |
| 73 | + endpoint, |
| 74 | + requester, |
| 75 | + asynchronous, |
| 76 | + ) |
| 77 | + self._get_response = get_response |
| 78 | + return self |
| 79 | + |
| 80 | + @classmethod |
| 81 | + def fit_from_interface( |
| 82 | + cls, |
| 83 | + name: str, |
| 84 | + dev_model_name: str, |
| 85 | + tail_model_name: str, |
| 86 | + model_class: str, |
| 87 | + endpoint: str, |
| 88 | + requester: Requester, |
| 89 | + asynchronous: bool = False, |
| 90 | + ) -> CashflowModel: |
| 91 | + """This method fits a new model and constructs a CashflowModel instance. |
| 92 | + It's intended to be used from the `ModelInterface` class mainly, |
| 93 | + and in the future will likely be superseded by having separate |
| 94 | + `create` and `fit` API endpoints. |
| 95 | + """ |
| 96 | + |
| 97 | + post_data = { |
| 98 | + "development_model_name": dev_model_name, |
| 99 | + "tail_model_name": tail_model_name, |
| 100 | + "name": name, |
| 101 | + "model_config": {}, |
| 102 | + } |
| 103 | + fit_response = requester.post(endpoint, data=post_data) |
| 104 | + id = fit_response.json()["model"]["id"] |
| 105 | + self = cls( |
| 106 | + id=id, |
| 107 | + name=name, |
| 108 | + dev_model_name=dev_model_name, |
| 109 | + tail_model_name=tail_model_name, |
| 110 | + model_class=model_class, |
| 111 | + endpoint=endpoint + f"/{id}", |
| 112 | + requester=requester, |
| 113 | + asynchronous=asynchronous, |
| 114 | + ) |
| 115 | + |
| 116 | + self._fit_response = fit_response |
| 117 | + |
| 118 | + return self |
| 119 | + |
| 120 | + def predict( |
| 121 | + self, |
| 122 | + triangle: str | Triangle, |
| 123 | + config: JSONDict | None = None, |
| 124 | + initial_loss_triangle: Triangle | str | None = None, |
| 125 | + prediction_name: str | None = None, |
| 126 | + timeout: int = 300, |
| 127 | + ) -> Triangle: |
| 128 | + triangle_name = triangle if isinstance(triangle, str) else triangle.name |
| 129 | + config = { |
| 130 | + "triangle_name": triangle_name, |
| 131 | + "predict_config": self.PredictConfig(**(config or {})).__dict__, |
| 132 | + } |
| 133 | + if prediction_name: |
| 134 | + config["prediction_name"] = prediction_name |
| 135 | + |
| 136 | + if isinstance(initial_loss_triangle, Triangle): |
| 137 | + config["predict_config"]["initial_loss_name"] = initial_loss_triangle.name |
| 138 | + elif isinstance(initial_loss_triangle, str): |
| 139 | + config["predict_config"]["initial_loss_name"] = initial_loss_triangle |
| 140 | + |
| 141 | + url = self.endpoint + "/predict" |
| 142 | + self._predict_response = self._requester.post(url, data=config) |
| 143 | + |
| 144 | + if self._asynchronous: |
| 145 | + return self |
| 146 | + |
| 147 | + task_id = self.predict_response.json()["modal_task"]["id"] |
| 148 | + task_response = self._poll_remote_task( |
| 149 | + task_id=task_id, |
| 150 | + task_name=f"Predicting from model '{self.name}' on triangle '{triangle_name}'", |
| 151 | + timeout=timeout, |
| 152 | + ) |
| 153 | + if task_response.get("status") != "success": |
| 154 | + raise ValueError(f"Task failed: {task_response['error']}") |
| 155 | + triangle_id = self.predict_response.json()["predictions"] |
| 156 | + triangle = TriangleInterface( |
| 157 | + host=self.endpoint.replace(f"{self.model_class_slug}/{self.id}", ""), |
| 158 | + requester=self._requester, |
| 159 | + ).get(id=triangle_id) |
| 160 | + return triangle |
| 161 | + |
| 162 | + def delete(self) -> CashflowModel: |
| 163 | + self._delete_response = self._requester.delete(self.endpoint) |
| 164 | + return self |
| 165 | + |
| 166 | + def _poll(self, task_id: str) -> JSONDict: |
| 167 | + endpoint = self.endpoint.replace( |
| 168 | + f"{self.model_class_slug}/{self.id}", f"tasks/{task_id}" |
| 169 | + ) |
| 170 | + return self._requester.get(endpoint) |
| 171 | + |
| 172 | + def _poll_remote_task( |
| 173 | + self, task_id: str, task_name: str = "", timeout: int = 300 |
| 174 | + ) -> dict: |
| 175 | + start = time.time() |
| 176 | + status = ["CREATED"] |
| 177 | + console = Console() |
| 178 | + with console.status("Working...", spinner="bouncingBar") as _: |
| 179 | + while time.time() - start < timeout: |
| 180 | + task = self._poll(task_id).json() |
| 181 | + modal_status = ( |
| 182 | + "FINISHED" if task["task_response"] is not None else "PENDING" |
| 183 | + ) |
| 184 | + status.append(modal_status) |
| 185 | + if status[-1] != status[-2]: |
| 186 | + console.log(f"{task_name}: {status[-1]}") |
| 187 | + if status[-1].lower() == "finished": |
| 188 | + return task["task_response"] |
| 189 | + raise TimeoutError(f"Task '{task}' timed out") |
| 190 | + |
| 191 | + class PredictConfig(ValidationConfig): |
| 192 | + """Cashflow model configuration class. |
| 193 | +
|
| 194 | + Attributes: |
| 195 | + use_bf: Whether or not to use Bornhuetter-Ferguson method to adjust reserve estimates. |
| 196 | + use_reverse_bf: Whether or not to use the Reverse B-F method to adjust reserve |
| 197 | + estimates. |
| 198 | + gamma: Gamma parameter in the Reverse B-F method. |
| 199 | + min_reserve: Minimum reserve amounts as a function of development lag. |
| 200 | + seed: Seed to use for model sampling. Defaults to ``None``, but it is highly recommended |
| 201 | + to set. |
| 202 | + """ |
| 203 | + |
| 204 | + use_bf: bool = True |
| 205 | + use_reverse_bf: bool = True |
| 206 | + gamma: float = 0.7 |
| 207 | + min_reserve: Dict[float, float] | None |
| 208 | + seed: int | None = None |
0 commit comments