docker build . -t ohwi/triton_test && docker run --rm --gpus '"device=0"' ohwi/triton_test테스트는 cuda 12.2 버전이 설치된 A100 40GB에서 진행됐습니다.
테스트는 TransfomerEngine에서 FusedRoPE와 unfused version을 테스트 하기 위해서 사용한 함수에서 디폴트를 seq_length: 4096, hidden_size: 128, rotary_percent: 1.0, batch_size: 16, head_num: 32, margin: 0로 설정하고 하나씩 숫자를 바꿔가면서 진행 했습니다.
비교를 위해 TransfomerEngine에 구현되어 있는 Torch native unfused version, fused version, triton 구현체를 비교했습니다.
RoPE-performance-version-1-seq_length-test:
seq_length Torch Native Transformer Engine (Fused) Triton
0 512.0 2.506752 0.711680 0.713728
1 1024.0 4.720640 1.360896 1.373184
2 2048.0 9.277440 2.646016 2.646016
3 4096.0 18.423807 5.210112 5.197824
hidden_size = 128, 256, 512
RoPE-performance-version-1-hidden_size-test:
hidden_size Torch Native Transformer Engine (Fused) Triton
0 128.0 18.426880 5.208064 5.197824
1 256.0 36.747265 11.131392 9.598976
2 512.0 73.428986 24.171520 18.699265
hidden_size가 커지면 속도 차이가 생기는 점에서 쓰레드 블럭에 관련된 파라미터를 조절할 필요가 보였습니다. V1에서는 1d 쓰레드 블럭을 사용했는데, 전체 head에서 사용되는 연산이 같은 점을 이용해서 2d 쓰레드 블럭으로 확장이 가능해보였습니다.
RoPE-performance-version-1-rotary_percent-test:
rotary_percent Torch Native Transformer Engine (Fused) Triton
0 0.5 14.11584 5.099520 7.843840
1 1.0 18.42688 5.211136 5.197824
rotary_percent가 작으면 속도 많이 느려졌습니다. rotary_percent가 작으면 hidden_size가 작아지는 것과 비슷한 현상이 보이게 될 것이고 같은 문제에서 오는 현상으로 판단했습니다.
RoPE-performance-version-1-batch_size-test:
batch_size Torch Native Transformer Engine (Fused) Triton
0 2.0 2.425856 0.726016 0.720896
1 4.0 4.721664 1.375232 1.376256
2 8.0 9.276416 2.648064 2.649088
3 16.0 18.429953 5.208064 5.197824
RoPE-performance-version-1-head_num-test:
head_num Torch Native Transformer Engine (Fused) Triton
0 8.0 4.721664 1.444864 1.408000
1 16.0 9.275392 2.828288 2.671616
2 32.0 18.427904 5.208064 5.197824
3 64.0 36.743683 10.625024 10.288640
RoPE-performance-version-1-margin-test:
margin Torch Native Transformer Engine (Fused) Triton
0 0.0 18.425856 5.208064 5.198848
1 10.0 18.384895 5.198848 5.188608
2 33.0 18.279425 5.171200 5.158912
3 77.0 18.083839 5.113856 5.105664
V1에서 문제점이라고 판단한 부분을 수정했습니다.
- 1d 쓰레드 블럭 대신 2d 쓰레드 블럭을 사용
head_numxhidden_size를 기준으로 warp 숫자와 block size를 조절
결과:
RoPE-performance-version-2-seq_length-test:
seq_length Torch Native Transformer Engine (Fused) Triton
0 512.0 2.423808 0.711680 0.670720
1 1024.0 4.721664 1.367040 1.295360
2 2048.0 9.278464 2.647040 2.512896
3 4096.0 18.424831 5.233664 4.964352
hidden_size = 128, 256, 512
RoPE-performance-version-2-hidden_size-test:
hidden_size Torch Native Transformer Engine (Fused) Triton
0 128.0 18.429953 5.233664 4.964352
1 256.0 36.751873 11.133440 10.204160
2 512.0 73.436157 24.182783 22.141441
RoPE-performance-version-2-rotary_percent-test:
rotary_percent Torch Native Transformer Engine (Fused) Triton
0 0.5 14.116864 5.09952 4.785152
1 1.0 18.429953 5.23264 4.963328
RoPE-performance-version-2-batch_size-test:
batch_size Torch Native Transformer Engine (Fused) Triton
0 2.0 2.427904 0.722944 0.667648
1 4.0 4.722688 1.378304 1.285120
2 8.0 9.279488 2.649088 2.493440
3 16.0 18.426880 5.233664 4.963328
RoPE-performance-version-2-head_num-test:
head_num Torch Native Transformer Engine (Fused) Triton
0 8.0 4.723712 1.445888 1.314816
1 16.0 9.278464 2.829312 2.468864
2 32.0 18.428928 5.233664 4.964352
3 64.0 36.738045 10.626048 10.165248
RoPE-performance-version-2-margin-test:
margin Torch Native Transformer Engine (Fused) Triton
0 0.0 18.428928 5.234688 4.964352
1 10.0 18.385920 5.223424 4.954624
2 33.0 18.278400 5.193728 4.927488
3 77.0 18.087936 5.139456 4.875264
V1에 비해서 hidden_size가 512인 경우의 속도가 조금 느려졌습니다. 하지만 rotary_percent가 0.5인 경우의 문제점을 개선하고, TransformerEngine의 FusedRoPE와 비슷한 속도의 triton 버전을 만들 수 있었습니다.
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-7.4.0, pluggy-1.2.0
rootdir: /workspace
plugins: shard-0.1.2, xdist-3.3.1, flakefinder-1.1.0, rerunfailures-12.0, xdoctest-1.0.2, hypothesis-5.35.1
collected 576 items
Running 576 items in this shard
compare_rope.py ........................................................ [ 9%]
........................................................................ [ 22%]
........................................................................ [ 34%]
........................................................................ [ 47%]
........................................................................ [ 59%]
........................................................................ [ 72%]
........................................................................ [ 84%]
........................................................................ [ 97%]
................ [100%]
=============================== warnings summary ===============================
../usr/local/lib/python3.10/dist-packages/setuptools/__init__.py:9
/usr/local/lib/python3.10/dist-packages/setuptools/__init__.py:9: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives
import distutils.core
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================= 576 passed, 1 warning in 38.31s ========================











