improve transformer architecture, remove 3.10 install constraint, add documentation for torch.compile()

This commit is contained in:
robcaulk
2023-05-06 16:12:10 +00:00
parent af139ffbab
commit 3bbb7e38ea
3 changed files with 36 additions and 16 deletions

View File

@@ -395,3 +395,21 @@ Here we create a `PyTorchMLPRegressor` class that implements the `fit` method. T
return dataframe
```
To see a full example, you can refer to the [classifier test strategy class](https://github.com/freqtrade/freqtrade/blob/develop/tests/strategy/strats/freqai_test_classifier.py).
#### Improving performance with `torch.compile()`
Torch provides a `torch.compile()` method that can be used to improve performance for specific GPU hardware. More details can be found [here](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In brief, you simply wrap your `model` in `torch.compile()`:
```python
model = PyTorchMLPModel(
input_dim=n_features,
output_dim=1,
**self.model_kwargs
)
model.to(self.device)
model = torch.compile(model)
```
Then proceed to use the model as normal. Keep in mind that doing this will remove eager execution, which means errors and tracebacks will not be informative.