I would like to use flash attention 2 instead of the default PyTorch attention implementation, since it offers a significant inference speedup (2-10x according to the project's benchmarks).
According to the transformers docs, flash attention 2 can be used with AMD Instinct MI210, MI250 and MI300 GPUs.
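As a first sanity check, we can select the attention implementation at model-load time via the `attn_implementation` argument that transformers exposes. The sketch below is an assumption about how we would wire this up in our inference code (the fallback value `"sdpa"` is PyTorch's default scaled dot-product attention backend); it only picks the implementation string and does not load a model.

```python
def pick_attn_implementation():
    """Return "flash_attention_2" when transformers reports flash-attn 2 as
    available, otherwise fall back to PyTorch's default "sdpa" backend."""
    try:
        from transformers.utils import is_flash_attn_2_available
        if is_flash_attn_2_available():
            return "flash_attention_2"
    except ImportError:
        # transformers not installed in this environment
        pass
    return "sdpa"

# The chosen string would then be passed to from_pretrained, e.g.:
# model = AutoModelForCausalLM.from_pretrained(
#     model_id, attn_implementation=pick_attn_implementation(),
#     torch_dtype=torch.float16)
attn_impl = pick_attn_implementation()
```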
The documentation strongly suggests using this Dockerfile, which installs the ROCm/flash-attention package. We can migrate the required Dockerfile instructions into our Blubber file so that they are applied to our huggingface image.
There are 2 options for flash attention 2 on ROCm, as stated in the official documentation:
- CK flash attention
- Triton flash attention
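For the Triton option, my understanding (to be verified against the ROCm/flash-attention README) is that recent builds toggle the Triton backend through an environment variable rather than a separate package, roughly like this:

```shell
# Assumption: this variable name is taken from the ROCm flash-attention fork's
# docs for its Triton backend; confirm against the repo before relying on it.
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
```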
We will first build flash attention for ROCm for the MI210 (gfx90a architecture) on ml-lab. Once we have a working example, we will transfer it into a Docker image built on a Debian base image.
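For the Docker leg of the plan, the build steps might look like the fragment below. This is a sketch, not the final recipe: the base image is an assumption, and the `GPU_ARCHS=gfx90a` variable is how the ROCm flash-attention repo selects the target architecture at build time. These instructions would be translated into our Blubber file rather than used as a raw Dockerfile.

```dockerfile
# Sketch only: base image and paths are assumptions for illustration.
FROM rocm/pytorch:latest

# Build the CK flash attention 2 fork for the MI210 (gfx90a).
RUN git clone --recursive https://github.com/ROCm/flash-attention.git /tmp/flash-attention \
 && cd /tmp/flash-attention \
 && GPU_ARCHS=gfx90a pip install --no-build-isolation . \
 && rm -rf /tmp/flash-attention
```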
| flash attention version | platform | Status |
| --- | --- | --- |
| CK flash attention 2 | ml-lab | |
| CK flash attention 2 | docker | |
| Triton flash attention 2 | ml-lab | |
| Triton flash attention 2 | docker | |
