diff --git a/.github/workflows/docker-base-image.yml b/.github/workflows/docker-base-image.yml deleted file mode 100644 index bb79925c..00000000 --- a/.github/workflows/docker-base-image.yml +++ /dev/null @@ -1,68 +0,0 @@ -name: Docker Base Image CI - -on: - push: - branches: [ "base" ] - repository_dispatch: - types: [ build_base ] - -jobs: - build: - runs-on: ubuntu-latest - - permissions: - contents: read - packages: write - - steps: - # Step 1: Checkout the repository - - name: Checkout Code - uses: actions/checkout@v4 - - # Step 2: Log in to GitHub Container Registry - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - # Step 2: Set environemnt - - name: Set environment - env: - GIT_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - echo "IMAGE_TAG=torchsim-ci:${GITHUB_SHA}" >> $GITHUB_ENV - echo "GITHUB_SHA=${{github.event.pull_request.head.sha}}" >> $GITHUB_ENV - echo "GITHUB_SHA=${{github.event.pull_request.head.sha}}" - - gem5_response_file=/tmp/releases-gem5-latest.json - curl -s https://api.github.com/repos/PSAL-POSTECH/GEM5/releases/latest > ${gem5_response_file} - GEM5_ASSET_ID=$(jq ".assets[0].id" ${gem5_response_file}) - echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" - echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" >> $GITHUB_ENV - - llvm_response_file=/tmp/releases-gem5-latest.json - curl -s https://api.github.com/repos/PSAL-POSTECH/llvm-project/releases/latest > ${llvm_response_file} - LLVM_ASSET_ID=$(jq ".assets[0].id" ${llvm_response_file}) - echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" - echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" >> $GITHUB_ENV - - spike_response_file=/tmp/releases-spike-latest.json - curl -s https://api.github.com/repos/PSAL-POSTECH/riscv-isa-sim/releases/latest > ${spike_response_file} - SPIKE_ASSET_ID=$(jq ".assets[0].id" ${spike_response_file}) - echo "SPIKE_ASSET_ID=$SPIKE_ASSET_ID" - echo "SPIKE_ASSET_ID=$SPIKE_ASSET_ID" >> $GITHUB_ENV - - # Step 3: Build and Push Docker Image - - name: Build and Push Docker Image - uses: docker/build-push-action@v4 - with: - context: . - file: ./Dockerfile.base - push: true - build-args: | - GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} - LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} - SPIKE_ASSET_ID=${{ env.SPIKE_ASSET_ID }} - tags: ghcr.io/psal-postech/torchsim_base:latest diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index eba48da2..67140c89 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -3,9 +3,81 @@ name: Docker image CI on: pull_request: branches: [ "master", "develop" ] + workflow_dispatch: + +env: + BASE_IMAGE_REPO: ghcr.io/psal-postech/torchsim_base + # PR: head commit; otherwise workflow_dispatch uses the branch SHA + SOURCE_SHA: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} jobs: + ensure-base: + runs-on: ubuntu-latest + outputs: + base_image: ${{ steps.pin.outputs.base_image }} + permissions: + contents: read + packages: write + + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + ref: ${{ env.SOURCE_SHA }} + submodules: recursive + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: PyTorch base image from manifest + run: | + PYTORCH_IMAGE=$(python3 -c "import json; from pathlib import Path; v=json.loads(Path('thirdparty/github-releases.json').read_text()).get('pytorch_image'); print(v or '')") + if [ -z "$PYTORCH_IMAGE" ]; then echo "thirdparty/github-releases.json: pytorch_image is required" >&2; exit 1; fi + echo "PYTORCH_IMAGE=$PYTORCH_IMAGE" >> "$GITHUB_ENV" + + - name: Thirdparty pin + id: pin + run: | + PIN="$(bash scripts/ci/thirdparty_base_pin.sh)" + echo "pin=${PIN}" >> "$GITHUB_OUTPUT" + echo "base_image=${BASE_IMAGE_REPO}:thirdparty-${PIN}" >> "$GITHUB_OUTPUT" + echo "BASE_IMAGE=${BASE_IMAGE_REPO}:thirdparty-${PIN}" >> "$GITHUB_ENV" + + - name: Check base image exists + id: exists + run: | + if docker manifest inspect "${BASE_IMAGE}" > /dev/null 2>&1; then + echo "ok=true" >> "$GITHUB_OUTPUT" + else + echo "ok=false" >> "$GITHUB_OUTPUT" + fi + + - name: Resolve GitHub release asset IDs + if: steps.exists.outputs.ok != 'true' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: bash scripts/ci/thirdparty_github_asset_env.sh >> "$GITHUB_ENV" + + - name: Build and push base image (missing pin) + if: steps.exists.outputs.ok != 'true' + uses: docker/build-push-action@v4 + with: + context: . + file: ./Dockerfile.base + push: true + build-args: | + PYTORCH_IMAGE=${{ env.PYTORCH_IMAGE }} + GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} + LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} + SPIKE_ASSET_ID=${{ env.SPIKE_ASSET_ID }} + tags: ${{ env.BASE_IMAGE }} + build-and-test: + needs: ensure-base runs-on: self-hosted permissions: @@ -13,14 +85,12 @@ jobs: packages: write steps: - # Step 1: Checkout the repository - name: Checkout Code uses: actions/checkout@v4 with: - ref: ${{ github.event.pull_request.head.sha }} + ref: ${{ env.SOURCE_SHA }} submodules: recursive - # Step 2: Log in to GitHub Container Registry - name: Login to GHCR uses: docker/login-action@v3 with: @@ -28,7 +98,6 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - # Step 3: Build and Push Docker Image - name: Build and Push Docker Image uses: docker/build-push-action@v6 with: @@ -36,35 +105,45 @@ jobs: file: ./Dockerfile push: true no-cache: true - tags: ghcr.io/psal-postech/torchsim-test:${{ github.sha }} + build-args: | + BASE_IMAGE=${{ needs.ensure-base.outputs.base_image }} + tags: ghcr.io/psal-postech/torchsim-test:${{ env.SOURCE_SHA }} - # Step 4: Wait for GHCR propagation + # Do not use GITHUB_SHA here: on pull_request it is the merge commit, while the image tag uses SOURCE_SHA (PR head). - name: Wait for GHCR propagation + env: + IMAGE_SHA: ${{ env.SOURCE_SHA }} run: | - for i in {1..30}; do + IMG="ghcr.io/psal-postech/torchsim-test:${IMAGE_SHA}" + echo "Verifying tag matches push: ${IMAGE_SHA}" + for i in $(seq 1 30); do echo "Checking if image exists in GHCR (attempt $i)..." - if docker manifest inspect ghcr.io/psal-postech/torchsim-test:${GITHUB_SHA} > /dev/null 2>&1; then + if docker buildx imagetools inspect "$IMG" > /dev/null 2>&1; then echo "Image is now available in GHCR." exit 0 fi - echo "Image not yet available, retrying in 30 seconds..." + if [ "$i" -eq 1 ]; then + echo "buildx imagetools inspect failed; stderr (first attempt):" + docker buildx imagetools inspect "$IMG" 2>&1 || true + fi + echo "Image not yet available, retrying in 20 seconds..." sleep 20 done echo "Image did not become available in GHCR within expected time." exit 1 - test-pytorchsim-wrapper: + test-pytorchsim-wrapper1: needs: build-and-test uses: ./.github/workflows/pytorchsim_test.yml with: - image_name: ghcr.io/psal-postech/torchsim-test:${{ github.sha }} + image_name: ghcr.io/psal-postech/torchsim-test:${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} vector_lane: 128 spad_size: 128 -# call-test2: -# needs: build-and-test -# uses: ./.github/workflows/pytorchsim_test.yml -# with: -# image_name: ghcr.io/psal-postech/${GITHUB_SHA} -# vector_lane: 8 -# spad_size: 32 \ No newline at end of file + test-pytorchsim-wrapper2: + needs: build-and-test + uses: ./.github/workflows/pytorchsim_test.yml + with: + image_name: ghcr.io/psal-postech/torchsim-test:${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + vector_lane: 32 + spad_size: 32 diff --git a/.github/workflows/docker-tutorial-image.yml b/.github/workflows/docker-tutorial-image.yml index c7d3a2ca..c0d8267d 100644 --- a/.github/workflows/docker-tutorial-image.yml +++ b/.github/workflows/docker-tutorial-image.yml @@ -30,6 +30,6 @@ jobs: uses: docker/build-push-action@v4 with: context: . - file: ./Dockerfile.ksc2025 + file: ./tutorial/jupyterhub/Dockerfile.ksc2025 push: true tags: ghcr.io/psal-postech/torchsim_ksc2025:latest diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index fe8a4a7d..a7613b6e 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -31,8 +31,6 @@ jobs: run: | echo "Running test_add.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_add.py @@ -52,8 +50,6 @@ jobs: run: | echo "Running test_transcendental.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_transcendental.py @@ -73,8 +69,6 @@ jobs: run: | echo "Running test_activation.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_activation.py @@ -94,8 +88,6 @@ jobs: run: | echo "Running test_batchnorm.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_batchnorm.py @@ -115,8 +107,6 @@ jobs: run: | echo "Running test_bmm.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_bmm.py @@ -136,8 +126,6 @@ jobs: run: | echo "Running test_cnn.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_cnn.py @@ -157,12 +145,29 @@ jobs: run: | echo "Running test_conv2d.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_conv2d.py + test_cat: + name: Run test_cat.py + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_cat.py + run: | + echo "Running test_cat.py" + docker run --rm \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/test_cat.py + test_matmul: name: Run test_matmul.py runs-on: self-hosted @@ -178,8 +183,6 @@ jobs: run: | echo "Running test_matmul.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_matmul.py @@ -199,8 +202,6 @@ jobs: run: | echo "Running test_reduce.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_reduce.py @@ -220,8 +221,6 @@ jobs: run: | echo "Running test_softmax.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_softmax.py @@ -241,8 +240,6 @@ jobs: run: | echo "Running test_transpose2D.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_transpose2D.py @@ -262,8 +259,6 @@ jobs: run: | echo "Running test_view3D_2D.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_view3D_2D.py @@ -283,8 +278,6 @@ jobs: run: | echo "Running test_layernorm.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_layernorm.py @@ -304,8 +297,6 @@ jobs: run: | echo "Running test_mlp.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_mlp.py @@ -325,8 +316,6 @@ jobs: run: | echo "Running test_resnet.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_resnet.py @@ -335,12 +324,29 @@ jobs: run: | echo "Running test_resnet.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_resnet.py --model_type resnet50 + test_mobilenet: + name: Run test_mobilenet.py + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_mobilenet.py + run: | + echo "Running test_mobilenet.py" + docker run --rm \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/MobileNet/test_mobilenet.py + test_transformer: name: Run test_transformer.py runs-on: self-hosted @@ -356,8 +362,6 @@ jobs: run: | echo "Running test_transformer.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_transformer.py @@ -377,8 +381,6 @@ jobs: run: | echo "Running test_transpose3D.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_transpose3D.py @@ -398,8 +400,6 @@ jobs: run: | echo "Running test_sparsity.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_sparsity.py @@ -419,8 +419,6 @@ jobs: run: | echo "Running test_pool.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_pool.py @@ -440,8 +438,6 @@ jobs: run: | echo "Running test_single_perceptron.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_single_perceptron.py @@ -461,8 +457,6 @@ jobs: run: | echo "Running test_addmm_residual.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py @@ -471,8 +465,6 @@ jobs: run: | echo "Running test_matmul_activation.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py @@ -481,8 +473,6 @@ jobs: run: | echo "Running test_matmul_scalar.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py @@ -491,8 +481,6 @@ jobs: run: | echo "Running test_matmul_reduction.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py @@ -501,8 +489,6 @@ jobs: run: | echo "Running test_bmm_reduction.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_bmm_reduction.py @@ -511,8 +497,6 @@ jobs: run: | echo "Running test_prologue_fusion.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_prologue_fusion.py @@ -521,8 +505,6 @@ jobs: run: | echo "Running test_transformer_fusion.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_transformer_fusion.py @@ -531,8 +513,6 @@ jobs: run: | echo "Running test_conv_fusion.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Fusion/test_conv_fusion.py @@ -552,8 +532,6 @@ jobs: run: | echo "Running test_moe.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/MoE/test_moe.py @@ -573,8 +551,6 @@ jobs: run: | echo "Running test_mistral.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Mixtral_8x7B/test_attention.py @@ -594,8 +570,6 @@ jobs: run: | echo "Running test_vit.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_vit.py @@ -615,8 +589,6 @@ jobs: run: | echo "Running test_diffusion.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Diffusion/test_diffusion.py @@ -636,8 +608,6 @@ jobs: run: | echo "Running test_indirect.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_indirect_access.py @@ -657,15 +627,71 @@ jobs: run: | echo "Running test_scheduler.py" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_scheduler.py + test_llama: + name: Run test_llama1&2 + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_llama.py + run: | + echo "Running test_llama.py" + docker run --rm \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/Llama/test_llama.py + + test_yolov5: + name: Run test_yolov5 + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_yolov5.py + run: | + echo "Running test_yolov5.py" + docker run --rm \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/Yolov5/test_yolov5.py + + test_deepseek: + name: Run test_deepseek + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_deepseek_v3_base.py + run: | + echo "Running test_deepseek_v3_base.py" + docker run --rm \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/DeepSeek/test_deepseek_v3_base.py + test_accuracy: name: Run test_accuracy runs-on: self-hosted + if: inputs.vector_lane == 128 steps: - name: Log in to GitHub Container Registry uses: docker/login-action@v3 @@ -674,25 +700,18 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Prepare volume directory - run: mkdir -p /tmp/torchsim-ci/${GITHUB_SHA} - - name: Run run_cycle.sh run: | echo "Running run_cycle.sh" docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ -e vpu_num_lanes="${{ inputs.vector_lane }}" \ -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} bash -c \ - "cd /workspace && PyTorchSim/experiments/artifact/cycle_validation/run_cycle.sh && \ - cp PyTorchSim/experiments/artifact/cycle_validation/summary_cycle.out /dump/summary_cycle.out" - ls /tmp/torchsim-ci/${GITHUB_SHA} + "cd /workspace && PyTorchSim/experiments/artifact/cycle_validation/run_cycle.sh >/dev/null 2>&1 && cat PyTorchSim/experiments/artifact/cycle_validation/summary_cycle.out" > summary_cycle.out - name: Upload Accuracy Report Artifact uses: actions/upload-artifact@v4 with: name: accuracy-report - path: /tmp/torchsim-ci/${{ github.sha }}/summary_cycle.out + path: summary_cycle.out if-no-files-found: error diff --git a/.github/workflows/tag_release.yml b/.github/workflows/tag_release.yml index 0728a583..f92fc060 100644 --- a/.github/workflows/tag_release.yml +++ b/.github/workflows/tag_release.yml @@ -5,8 +5,80 @@ on: tags: - 'v*' +env: + BASE_IMAGE_REPO: ghcr.io/psal-postech/torchsim_base + jobs: + ensure-base: + runs-on: ubuntu-latest + outputs: + base_image: ${{ steps.pin.outputs.base_image }} + permissions: + contents: read + packages: write + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + repository: PSAL-POSTECH/PyTorchSim + ref: ${{ github.sha }} + submodules: recursive + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: PyTorch base image from manifest + run: | + PYTORCH_IMAGE=$(python3 -c "import json; from pathlib import Path; v=json.loads(Path('thirdparty/github-releases.json').read_text()).get('pytorch_image'); print(v or '')") + if [ -z "$PYTORCH_IMAGE" ]; then echo "thirdparty/github-releases.json: pytorch_image is required" >&2; exit 1; fi + echo "PYTORCH_IMAGE=$PYTORCH_IMAGE" >> "$GITHUB_ENV" + + - name: Thirdparty pin + id: pin + run: | + PIN="$(bash scripts/ci/thirdparty_base_pin.sh)" + echo "pin=${PIN}" >> "$GITHUB_OUTPUT" + echo "base_image=${BASE_IMAGE_REPO}:thirdparty-${PIN}" >> "$GITHUB_OUTPUT" + echo "BASE_IMAGE=${BASE_IMAGE_REPO}:thirdparty-${PIN}" >> "$GITHUB_ENV" + + - name: Check base image exists + id: exists + run: | + if docker manifest inspect "${BASE_IMAGE}" > /dev/null 2>&1; then + echo "ok=true" >> "$GITHUB_OUTPUT" + else + echo "ok=false" >> "$GITHUB_OUTPUT" + fi + + - name: Resolve GitHub release asset IDs + if: steps.exists.outputs.ok != 'true' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: bash scripts/ci/thirdparty_github_asset_env.sh >> "$GITHUB_ENV" + + - name: Build and push base image (missing pin) + if: steps.exists.outputs.ok != 'true' + uses: docker/build-push-action@v4 + with: + context: . + file: ./Dockerfile.base + push: true + build-args: | + PYTORCH_IMAGE=${{ env.PYTORCH_IMAGE }} + GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} + LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} + SPIKE_ASSET_ID=${{ env.SPIKE_ASSET_ID }} + tags: | + ${{ env.BASE_IMAGE }} + ${{ env.BASE_IMAGE_REPO }}:latest + build: + needs: ensure-base runs-on: self-hosted permissions: @@ -42,4 +114,6 @@ jobs: push: true secrets: | GIT_ACCESS_TOKEN=${{ secrets.GIT_ACCESS_TOKEN }} - tags: ghcr.io/psal-postech/${{ env.IMAGE_TAG}} \ No newline at end of file + build-args: | + BASE_IMAGE=${{ needs.ensure-base.outputs.base_image }} + tags: ghcr.io/psal-postech/${{ env.IMAGE_TAG }} diff --git a/.gitignore b/.gitignore index b42d5f6b..3ca1e54b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ __pycache__/ TOGSim/build/ .vscode -*.txt *.ipynb_checkpoints output togsim_results/* diff --git a/AsmParser/tog_generator.py b/AsmParser/tog_generator.py index 5f586d99..a12460e3 100644 --- a/AsmParser/tog_generator.py +++ b/AsmParser/tog_generator.py @@ -37,7 +37,7 @@ class tog_generator: StonneTraceCompute= 6 StonneTraceLoad = 7 StonneTraceStore = 8 - def __init__(self, origins="Unknown") -> None: + def __init__(self, origins={"Unknown"}) -> None: self.module_name = "tile_operation_graph" self.module = None self.raw_graph = {} @@ -226,7 +226,7 @@ def generate_tile_graph(self, name="tile_graph", cycle_list=list, x_offset=int, offset = w_offset if is_preload else x_offset iter_node.torchsim_overlapping_cycle = max(iter_node.torchsim_cycle - offset, 0) - origin_info = "_".join(map(str, self.origins)) + origin_info = self.origins if isinstance(self.origins, str) else "_".join(map(str, self.origins)) onnx_node_list = [node.to_onnx() for node in node_list] # Exclude root node dump_onnx_graph(name, onnx_node_list, vector_lane, origin_info, stonneGraph=stonneGraph) diff --git a/Dockerfile b/Dockerfile index 37721940..1c52d32f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ # syntax=docker/dockerfile:1.4 -FROM ghcr.io/psal-postech/torchsim_base:latest +ARG BASE_IMAGE=ghcr.io/psal-postech/torchsim_base:latest +FROM ${BASE_IMAGE} # Prepare PyTorchSim project COPY . /workspace/PyTorchSim @@ -9,4 +10,7 @@ RUN cd PyTorchSim/TOGSim && \ cd build && \ conan install .. --build=missing && \ cmake .. && \ - make -j$(nproc) \ No newline at end of file + make -j$(nproc) + +RUN cd PyTorchSim/PyTorchSimDevice && \ + python -m pip install --no-build-isolation -e . \ No newline at end of file diff --git a/Dockerfile.base b/Dockerfile.base index 1ac5e175..05444d41 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -23,7 +23,8 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime +ARG PYTORCH_IMAGE=pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime +FROM ${PYTORCH_IMAGE} # Copied from Gem5 Docker file ENV DEBIAN_FRONTEND=noninteractive @@ -33,7 +34,7 @@ RUN apt -y update && \ python3-dev python-is-python3 libboost-all-dev \ libhdf5-serial-dev python3-pydot libpng-dev libelf-dev pkg-config pip \ python3-venv black libssl-dev libasan5 libubsan1 curl device-tree-compiler wget ninja-build && \ - pip install onnx matplotlib scikit-learn && pip install --user conan==1.56.0 && rm -rf /var/lib/apt/lists/* + pip install onnx matplotlib scikit-learn pydot tabulate flash_attn && pip install --user conan==1.56.0 cmake==3.26.4 && rm -rf /var/lib/apt/lists/* # Download RISC-V tool chain RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2023.12.14/riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.12.14-nightly.tar.gz && \ @@ -44,6 +45,14 @@ RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2 # Install torchsim dependency RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0 && pip install "transformers<4.44" && pip install diffusers==0.34.0 +# FlashAttention +RUN python -m pip install --no-build-isolation flash-attn + +# Extra Python deps for YOLO/vision tests +RUN python -m pip install -U pip setuptools wheel && \ + python -m pip install --no-cache-dir --no-deps ultralytics && \ + python -m pip install --no-cache-dir opencv-python-headless pandas seaborn + ENV RISCV=/workspace/riscv ENV PATH=$RISCV/bin:$PATH @@ -67,9 +76,7 @@ RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/P # Store RISC-V LLVM for TorchSim ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin -ENV TORCHSIM_LLVM_INCLUDE_PATH=/riscv-llvm/include ENV TORCHSIM_DIR=/workspace/PyTorchSim -ENV LLVM_DIR=/riscv-llvm # Download Spike simulator RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/PSAL-POSTECH/riscv-isa-sim/releases/assets/${SPIKE_ASSET_ID} -o /tmp/spike-release.tar.gz && \ diff --git a/PyTorchSimDevice/CMakeLists.txt b/PyTorchSimDevice/CMakeLists.txt new file mode 100644 index 00000000..2c207ca6 --- /dev/null +++ b/PyTorchSimDevice/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(TORCH_OPENREG CXX C) + +include(GNUInstallDirs) +include(CheckCXXCompilerFlag) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_SKIP_BUILD_RPATH FALSE) +set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) +set(CMAKE_CXX_VISIBILITY_PRESET hidden) + +set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + +if(APPLE) + set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path") +elseif(UNIX) + set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN") +elseif(WIN32) + set(CMAKE_INSTALL_RPATH "") +endif() +set(CMAKE_INSTALL_LIBDIR lib) +set(CMAKE_INSTALL_MESSAGE NEVER) + +set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch) +find_package(Torch REQUIRED) + +if(DEFINED PYTHON_INCLUDE_DIR) + include_directories(${PYTHON_INCLUDE_DIR}) +else() + message(FATAL_ERROR "Cannot find Python directory") +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake) + +add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg) +add_subdirectory(${PROJECT_SOURCE_DIR}/csrc) +add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc) diff --git a/PyTorchSimDevice/README.md b/PyTorchSimDevice/README.md new file mode 100644 index 00000000..83ec85b1 --- /dev/null +++ b/PyTorchSimDevice/README.md @@ -0,0 +1,175 @@ +# PyTorch OpenReg + +## Background + +The third-party device integration mechanism based on PrivateUse1 has become the official mainstream method for new backends to integrate with PyTorch. Ensuring the availability of this mechanism is crucial for enriching PyTorch's hardware ecosystem. + +**Note:** + +The goal of `torch_openreg` is **not to implement a fully functional, high-performance PyTorch backend**, but to serve as a **minimalist reference implementation for mechanism verification**. + +### Purpose + +- **Test Backend**: To serve as an in-tree test backend for PrivateUse1, ensuring quality stability through CI/CD. +- **Integration Example**: To serve as a reference example for new backend integration. +- **Integration Documentation**: To provide module-level integration documentation that corresponds with the code. + +### Design Principles + +- **Minimality Principle**: The fundamental goal is to enable/verify all integration paths/mechanisms for a new backend to integrate to PyTorch. All functions follow a "just right" strategy to ensure the correctness of relevant integration capabilities. +- **Authenticity Principle**: To complete the OpenReg integration in the same way a real accelerator backend would integrate with PyTorch. + +## Directory Structure + +```shell +torch_openreg/ +├── CMakeLists.txt +├── csrc +│ ├── aten +│ │ ├── native +│ │ │ ├── Extra.cpp +│ │ │ ├── Minimal.cpp +│ │ │ └── ... +│ │ ├── OpenRegExtra.cpp +│ │ └── OpenRegMinimal.cpp +│ ├── CMakeLists.txt +│ └── runtime +│ ├── OpenRegDeviceAllocator.cpp +│ ├── OpenRegDeviceAllocator.h +│ ├── OpenRegFunctions.cpp +│ ├── OpenRegFunctions.h +│ ├── OpenRegGenerator.cpp +│ ├── OpenRegGenerator.h +│ ├── OpenRegGuard.cpp +│ ├── OpenRegGuard.h +│ ├── OpenRegHooks.cpp +│ ├── OpenRegHooks.h +│ ├── OpenRegHostAllocator.cpp +│ ├── OpenRegHostAllocator.h +│ └── ... +├── pyproject.toml +├── README.md +├── setup.py +├── third_party +│ └── openreg +└── torch_openreg + ├── csrc + │ ├── CMakeLists.txt + │ ├── Module.cpp + │ └── stub.c + ├── __init__.py + └── openreg + ├── __init__.py + ├── meta.py + └── random.py +``` + +**Dependencies**: + +```mermaid +graph LR + A[Python] + B[_C.so] + C[libtorch_bindings.so] + D[libtorch_openreg.so] + E[libopenreg.so] + + A --> B --> C --> D --> E +``` + +There are 4 DSOs in torch_openreg, and the dependencies between them are as follows: + +- `_C.so`: + - **sources**: torch_openreg/csrc/stub.c + - **description**: Python C module entry point. +- `libtorch_bindings.so`: The bridging code between Python and C++ should go here. + - **sources**: torch_openreg/csrc + - **description**: A thin glue layer between Python and C++. +- `libtorch_openreg.so`: All core implementations should go here. + - **sources**: csrc + - **description**: All core functionality, such as device runtime, operators, etc. +- `libopenreg.so`: A DSO that uses the CPU to emulate a CUDA-like device, you can ignore it. + - **sources**: third_party/openreg + - **description**: Provides low-level device functionality similar to libcudart.so. + +**Key Directories**: + +- `csrc/`: Core device implementation, including operator registration, runtime, etc. + - `csrc/aten/`: Operator registration + - `csrc/aten/native/`: Specific operator implementations for the OpenReg device. + - `csrc/aten/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion). + - `csrc/aten/OpenRegExtra.cpp`: Implementations for other types of operators. + - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc. +- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU. +- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings). + - `torch_openreg/csrc/`: Python C++ binding code. + - `torch_openreg/openreg/`: Python API. + +## Currently Implemented Features + +### Operator Registration + +- Operator Implementation + + - Register for builtin PyTorch Operators + - `TORCH_LIBRARY_IMPL` form: See `empty.memory_format + - `STUB` form: See `abs_stub` + - Register for custom operators + - Schema Registration: See `custom_abs` + - Kernel Registration: See `custom_abs` + - Fallback Registration for `AutogradPriavateUse1`: See `custom_abs` + - Meta Registration: See `custom_abs` + - `torch.autograd.Function`: See `custom_autograd_fn_aliasing` + - Register for fallback + - Per-operator Fallback: See `sub.Tensor` + - Global Fallback: See `wrapper_cpu_fallback` + +## Installation and Usage + +### Installation + +```python +pip3 install --no-build-isolation -e . # for develop +pip3 install --no-build-isolation . # for install +``` + +### Usage Example + +After installation, you can use the `openreg` device in Python just like any other regular device. + +```python +import torch +import torch_openreg + +if not torch.openreg.is_available(): + print("OpenReg backend is not available in this build.") + exit() + +print("OpenReg backend is available!") + +device = torch.device("openreg") + +x = torch.tensor([[1., 2.], [3., 4.]], device=device) +y = x + 2 +print("Result y:\n", y) +print(f"Device of y: {y.device}") + +z = y.cpu() +print("Result z:\n", z) +print(f"Device of z: {z.device}") +``` + +## Future Plans + +- **Enhance Features**: + - Autoload + - AMP + - Device-agnostic APIs + - Memory Management + - Generator + - Distrubuted + - Custom Tensor&Storage + - ... +- **Improve Tests**: Add more test cases related to the integration mechanism. +- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation. +- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync. diff --git a/PyTorchSimDevice/cmake/TorchPythonTargets.cmake b/PyTorchSimDevice/cmake/TorchPythonTargets.cmake new file mode 100644 index 00000000..b7a807d2 --- /dev/null +++ b/PyTorchSimDevice/cmake/TorchPythonTargets.cmake @@ -0,0 +1,22 @@ +if(WIN32) + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib") +elseif(APPLE) + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib") +else() + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so") +endif() + +add_library(torch_python SHARED IMPORTED) + +set_target_properties(torch_python PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include" + INTERFACE_LINK_LIBRARIES "c10;torch_cpu" + IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}" +) + +add_library(torch_python_library INTERFACE IMPORTED) + +set_target_properties(torch_python_library PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "\$" + INTERFACE_LINK_LIBRARIES "\$;\$" +) diff --git a/PyTorchSimDevice/csrc/CMakeLists.txt b/PyTorchSimDevice/csrc/CMakeLists.txt new file mode 100644 index 00000000..e2ae2b3f --- /dev/null +++ b/PyTorchSimDevice/csrc/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LIBRARY_NAME torch_openreg) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg) +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/PyTorchSimDevice/csrc/amp/OpenRegAmp.h b/PyTorchSimDevice/csrc/amp/OpenRegAmp.h new file mode 100644 index 00000000..2f81e9d2 --- /dev/null +++ b/PyTorchSimDevice/csrc/amp/OpenRegAmp.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include + +namespace c10::openreg { + +OPENREG_EXPORT bool is_amp_enabled(); +OPENREG_EXPORT void set_amp_enabled(bool flag); +OPENREG_EXPORT at::ScalarType get_amp_dtype(); +OPENREG_EXPORT void set_amp_dtype(at::ScalarType dtype); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/amp/auto_cast_mode.cpp b/PyTorchSimDevice/csrc/amp/auto_cast_mode.cpp new file mode 100644 index 00000000..fd650026 --- /dev/null +++ b/PyTorchSimDevice/csrc/amp/auto_cast_mode.cpp @@ -0,0 +1,28 @@ +#include +#include +#include "OpenRegAmp.h" + +namespace { + bool g_amp_enabled = false; + at::ScalarType g_amp_dtype = at::kFloat; +} + +namespace c10::openreg { + +OPENREG_EXPORT bool is_amp_enabled() { + return g_amp_enabled; +} + +OPENREG_EXPORT void set_amp_enabled(bool flag) { + g_amp_enabled = flag; +} + +OPENREG_EXPORT at::ScalarType get_amp_dtype() { + return g_amp_dtype; +} + +OPENREG_EXPORT void set_amp_dtype(at::ScalarType dtype) { + g_amp_dtype = dtype; +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp new file mode 100644 index 00000000..f048f878 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp @@ -0,0 +1,163 @@ +#include "native/Extra.h" + +#include +#include +#include + +#include +#include + +namespace at::openreg { + +namespace { +at::Tensor wrapper_quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::openreg::quantize_per_tensor( + self, scale, zero_point, dtype); +} + +int64_t wrapper__fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + return at::native::openreg::_fused_sdp_choice( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); +} + +void wrapper_quantize_tensor_per_tensor_affine_stub( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) { + at::native::openreg::quantize_tensor_per_tensor_affine_stub( + rtensor, qtensor, scale, zero_point); +} + +std::tuple +wrapper_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return at::native::openreg:: + _scaled_dot_product_fused_attention_overrideable_backward( + grad_out, + query, + key, + value, + attn_bias, + grad_input_mask, + out, + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale); +} + +at::Tensor wrapper_custom_autograd_fn_returns_self(at::Tensor x) { + return at::native::openreg::custom_autograd_fn_returns_self(x); +} + +at::Tensor wrapper_custom_autograd_fn_aliasing(at::Tensor x) { + return at::native::openreg::custom_autograd_fn_aliasing(x); +} + +at::Tensor& wrapper_abs_out(const at::Tensor& self, at::Tensor& out) { + return at::native::openreg::abs_out(self, out); +} + +void wrapper_abs_stub(at::TensorIteratorBase& iter) { + at::native::openreg::abs_kernel(iter); +} + +at::Tensor wrapper_custom_abs(at::Tensor x) { + return at::native::openreg::custom_abs(x); +} +} // namespace + +using namespace at::native; +// Registration via STUB +// LITERALINCLUDE START: STUB DEFAULT +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &wrapper_abs_stub); +REGISTER_PRIVATEUSE1_DISPATCH( + quantize_tensor_per_tensor_affine_stub, + &wrapper_quantize_tensor_per_tensor_affine_stub); +REGISTER_PRIVATEUSE1_DISPATCH( + _fused_sdp_choice_stub, + &wrapper__fused_sdp_choice); +// LITERALINCLUDE END: STUB DEFAULT + +// Registration of custom operators +// LITERALINCLUDE START: CUSTOM OPERATOR SCHEMA +TORCH_LIBRARY(openreg, m) { + m.def("custom_abs(Tensor input)-> Tensor"); +} +// LITERALINCLUDE END: CUSTOM OPERATOR SCHEMA + +// LITERALINCLUDE START: CUSTOM OPERATOR DEFAULT +TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) { + m.impl("custom_abs", &wrapper_custom_abs); +} +// LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT + +// LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK +TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) { + m.fallback(torch::autograd::autogradNotImplementedFallback()); +} +// LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK + +// The rest is for testing purposes +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + /* + abs_stub only works if abs.out is also registered with PrivateUse1, because + abs.default is designed to redirect directly to abs.out, which calls + abs_stub. + */ + m.impl("abs.out", &wrapper_abs_out); + m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor); + m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice); + m.impl( + "_scaled_dot_product_fused_attention_overrideable_backward", + &wrapper_scaled_dot_product_fused_attention_overrideable_backward); +} + +TORCH_LIBRARY_FRAGMENT(openreg, m) { + m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor"); + m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); +} + +TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) { + m.impl( + "custom_autograd_fn_returns_self", + &wrapper_custom_autograd_fn_returns_self); + m.impl("custom_autograd_fn_aliasing", &wrapper_custom_autograd_fn_aliasing); +} + +} // namespace at::openreg diff --git a/PyTorchSimDevice/csrc/aten/OpenRegMinimal.cpp b/PyTorchSimDevice/csrc/aten/OpenRegMinimal.cpp new file mode 100644 index 00000000..21ab3fef --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/OpenRegMinimal.cpp @@ -0,0 +1,161 @@ +#include "native/Minimal.h" + +#include +#include + +#include +#include +#include +#include +#include + +namespace at::openreg { + +namespace { + +// LITERALINCLUDE START: EMPTY.MEMORY_FORMAT WRAPPER +at::Tensor wrapper_empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + return at::native::openreg::empty_memory_format( + size, + dtype_opt, + layout_opt, + device_opt, + pin_memory_opt, + memory_format_opt); +} +// LITERALINCLUDE END: EMPTY.MEMORY_FORMAT WRAPPER + +at::Tensor wrapper_empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + return at::native::openreg::empty_strided( + size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + +at::Tensor wrapper_as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + return at::native::openreg::as_strided(self, size, stride, storage_offset); +} + +const at::Tensor& wrapper_resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::openreg::resize_(self, size, memory_format); +} + +at::Tensor wrapper__reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::openreg::_reshape_alias(self, size, stride); +} + +at::Tensor wrapper__copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + return at::native::openreg::_copy_from(self, dst, non_blocking); +} + +at::Tensor wrapper__copy_from_and_resize( + const at::Tensor& self, + const at::Tensor& dst) { + return at::native::openreg::_copy_from_and_resize(self, dst); +} + +at::Scalar wrapper__local_scalar_densor(const at::Tensor& self) { + return at::native::openreg::_local_scalar_dense(self); +} + +at::Tensor& wrapper_set_source_Tensor_( + at::Tensor& self, + const at::Tensor& source) { + return at::native::openreg::set_source_Tensor_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_(at::Tensor& self, at::Storage source) { + return at::native::openreg::set_source_Storage_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_storage_offsetset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::native::openreg::set_source_Storage_storage_offset_( + result, storage, storage_offset, size, stride); +} + +at::Tensor wrapper_view(const at::Tensor& self, c10::SymIntArrayRef size) { + return at::native::openreg::view(self, size); +} + +// LITERALINCLUDE START: FALLBACK WRAPPER +void wrapper_cpu_fallback( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + const auto& op_name = op.schema().operator_name(); + + // Generate timestamp in format [YYYY-MM-DD HH:MM:SS.mmm] + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + auto ms = std::chrono::duration_cast( + now.time_since_epoch()) % 1000; + + std::tm tm_buf; + localtime_r(&time_t, &tm_buf); + + std::ostringstream oss; + oss << std::put_time(&tm_buf, "%Y-%m-%d %H:%M:%S"); + oss << '.' << std::setfill('0') << std::setw(3) << ms.count(); + + std::cerr << "[" << oss.str() << "] [INFO] [PyTorchSimDevice] [Eager Mode] Operator: " << op_name << std::endl; + + at::native::openreg::cpu_fallback(op, stack); +} +// LITERALINCLUDE END: FALLBACK WRAPPER + +} // namespace + +// LITERALINCLUDE START: TORCH_LIBRARY_IMPL DEFAULT +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("empty.memory_format", wrapper_empty_memory_format); + m.impl("empty_strided", wrapper_empty_strided); + m.impl("as_strided", wrapper_as_strided); + m.impl("resize_", wrapper_resize_); + m.impl("_reshape_alias", wrapper__reshape_alias); + m.impl("_copy_from", wrapper__copy_from); + m.impl("_copy_from_and_resize", wrapper__copy_from_and_resize); + m.impl("_local_scalar_dense", wrapper__local_scalar_densor); + m.impl("set_.source_Tensor", wrapper_set_source_Tensor_); + m.impl("set_.source_Storage", wrapper_set_source_Storage_); + m.impl( + "set_.source_Storage_storage_offset", + wrapper_set_source_Storage_storage_offsetset_); + m.impl("view", wrapper_view); +} +// LITERALINCLUDE END: TORCH_LIBRARY_IMPL DEFAULT + +// LITERALINCLUDE START: FALLBACK GLOBAL +TORCH_LIBRARY_IMPL(_, PrivateUse1, m) { + m.fallback( + torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>()); +} +// LITERALINCLUDE END: FALLBACK GLOBAL + +} // namespace at::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Common.h b/PyTorchSimDevice/csrc/aten/native/Common.h new file mode 100644 index 00000000..c17196d0 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Common.h @@ -0,0 +1,97 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace at::native::openreg { + +class MemoryGuard { + public: + template + explicit MemoryGuard(const Args&... args) { + (find_and_unprotect_tensors(args), ...); + } + + ~MemoryGuard() noexcept { + for (void* ptr : unprotected_pointers_) { + orMemoryProtect(ptr); + } + } + + MemoryGuard(const MemoryGuard&) = delete; + MemoryGuard& operator=(const MemoryGuard&) = delete; + MemoryGuard(MemoryGuard&&) = delete; + MemoryGuard& operator=(MemoryGuard&&) = delete; + + private: + template + void find_and_unprotect_tensors(const T& item) { + if constexpr (std::is_base_of_v) { + unprotect_if_needed(item); + } else if constexpr (std::is_same_v) { + if (item.isTensor()) { + unprotect_if_needed(item.toTensor()); + } else if (item.isTensorList()) { + for (const at::Tensor& tensor : item.toTensorListRef()) { + unprotect_if_needed(tensor); + } + } else if (item.isList()) { + for (const c10::IValue& element : item.toListRef()) { + find_and_unprotect_tensors(element); + } + } else if (item.isGenericDict()) { + for (const auto& [key, value] : item.toGenericDict()) { + find_and_unprotect_tensors(key); + find_and_unprotect_tensors(value); + } + } + } + } + + void unprotect_if_needed(const at::TensorBase& tensor) { + if (!tensor.defined() || !tensor.has_storage()) { + return; + } + + void* ptr = tensor.data_ptr(); + orPointerAttributes attr; + + if (orPointerGetAttributes(&attr, ptr) != orSuccess || + attr.type != orMemoryTypeDevice) { + return; + } + + auto [it, inserted] = unprotected_pointers_.insert(attr.pointer); + if (inserted) { + orMemoryUnprotect(attr.pointer); + } + } + + std::unordered_set unprotected_pointers_; +}; + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.cpp b/PyTorchSimDevice/csrc/aten/native/Extra.cpp new file mode 100644 index 00000000..eb76f5d7 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Extra.cpp @@ -0,0 +1,193 @@ +#include "Extra.h" + +namespace at::native::openreg { + +at::Tensor quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::quantize_per_tensor(self, scale, zero_point, dtype); +} + +int64_t _fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + + sdp::sdp_params params{query, key, value, attn_mask, dropout_p, is_causal, enable_gqa}; + + // Reject inputs that are fundamentally unsupported (e.g. wrong rank) + if (!sdp::check_tensor_shapes(params, /*debug=*/false)) { + return static_cast(sdp::SDPBackend::error); + } + + // q: (B, Hq, L, E) k/v: (B, H, S, E) + const int64_t Hq = query.size(-3); + const int64_t H = key.size(-3); + const int64_t L = query.size(-2); // query sequence length + const int64_t S = key.size(-2); // key/value sequence length + + // Conditions required by the MLIR FlashSDPA kernel: + // Prefill only : L == S (decode has L == 1, not supported) + // Non-GQA : Hq == H (equal query and KV heads) + // No dropout : template has no dropout implementation + // Dense tensors : no nested tensor support + const bool can_use_mlir_flash = + (L == S) && + (Hq == H) && !enable_gqa && + sdp::check_for_dropout(params, /*debug=*/false) && + sdp::check_nested_tensor(params, /*debug=*/false); + + const bool ctx_flash = at::globalContext().userEnabledFlashSDP(); + const bool ctx_math = at::globalContext().userEnabledMathSDP(); + + if (ctx_flash && can_use_mlir_flash) { + return static_cast(sdp::SDPBackend::overrideable); + } + + return static_cast(sdp::SDPBackend::math); +} + +void quantize_tensor_per_tensor_affine_stub( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) {} + +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return std::tuple( + at::empty_like(query), + at::empty_like(key), + at::empty_like(value), + at::empty_like(attn_bias)); +} + +namespace { +struct CustomAutogradFnReturnsSelf + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +struct CustomAutogradFnAliasing + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self.view_symint(self.sym_sizes()); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; +} // namespace + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { + return CustomAutogradFnReturnsSelf::apply(x); +} + +at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { + return CustomAutogradFnAliasing::apply(x); +} + +/* + This implementation is only used to test stub registration, so not all + capabilities are fully supported. + + Current Limitations: + - dtype: Float only + - input tensor: must be contiguous layout +*/ +// LITERALINCLUDE START: STUB ABS +void abs_kernel(at::TensorIteratorBase& iter) { + TORCH_CHECK(iter.ntensors() == 2, "Abs kernel expects 2 tensors"); + TORCH_CHECK( + iter.common_dtype() == at::ScalarType::Float, + "Abs kernel only supports float type"); + + auto& output_tensor = iter.tensor(0); + auto& input_tensor = iter.tensor(1); + + TORCH_CHECK( + input_tensor.sizes() == output_tensor.sizes(), + "Input and output tensor sizes must match."); + + auto abs_loop = [](float* out_ptr, const float* in_ptr, int64_t n) { + for (int64_t i = 0; i < n; ++i) { + out_ptr[i] = std::abs(in_ptr[i]); + } + }; + + MemoryGuard guard(input_tensor, output_tensor); + + if (iter.is_contiguous()) { + abs_loop( + static_cast(iter.data_ptr(0)), + static_cast(iter.data_ptr(1)), + iter.numel()); + } else { + TORCH_CHECK( + input_tensor.is_contiguous(), "Input tensor must be contiguous.") + + auto output = at::empty( + input_tensor.sizes(), + input_tensor.options().memory_format( + input_tensor.suggest_memory_format())); + + MemoryGuard guard(output); + + abs_loop( + static_cast(output.data_ptr()), + static_cast(iter.data_ptr(1)), + iter.numel()); + + output_tensor.copy_(output); + } +} +// LITERALINCLUDE END: STUB ABS + +at::Tensor& abs_out(const at::Tensor& self, at::Tensor& out) { + return at::native::abs_out(self, out); +} + +at::Tensor custom_abs(at::Tensor x) { + return at::abs(x); +} + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.h b/PyTorchSimDevice/csrc/aten/native/Extra.h new file mode 100644 index 00000000..f002949a --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Extra.h @@ -0,0 +1,69 @@ +#include "Common.h" + +namespace at::native::openreg { + +at::Tensor quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype); +int64_t _fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa); +void quantize_tensor_per_tensor_affine_stub( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point); +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale); +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale); + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x); +at::Tensor custom_autograd_fn_aliasing(at::Tensor x); +at::Tensor& abs_out(const at::Tensor& self, at::Tensor& out); +void abs_kernel(at::TensorIteratorBase& iter); +at::Tensor custom_abs(at::Tensor x); + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Minimal.cpp b/PyTorchSimDevice/csrc/aten/native/Minimal.cpp new file mode 100644 index 00000000..8a3263bb --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Minimal.cpp @@ -0,0 +1,185 @@ +#include "Minimal.h" + +#include + +namespace at::native::openreg { + +// LITERALINCLUDE START: EMPTY.MEMORY_FORMAT IMPL +at::Tensor empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_generic( + size, allocator, pu1_dks, dtype, memory_format_opt); +} +// LITERALINCLUDE END: EMPTY.MEMORY_FORMAT IMPL + +at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_strided_generic( + size, stride, allocator, pu1_dks, dtype); +} + +at::Tensor as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + MemoryGuard guard(self); + + return at::cpu::as_strided_symint(self, size, stride, storage_offset); +} + +const at::Tensor& resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_( + self, C10_AS_INTARRAYREF_SLOW(size), memory_format); +} + +at::Tensor _reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::_reshape_alias( + self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride)); +} + +at::Tensor _copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + TORCH_CHECK(self.defined(), "Source tensor (self) is not defined."); + TORCH_CHECK(dst.defined(), "Destination tensor (dst) is not defined."); + + MemoryGuard guard(self, dst); + + if (self.device() == dst.device()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + const at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self_as_cpu, non_blocking); + + } else { + if (self.is_cpu()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self, non_blocking); + + } else { + at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst), self_as_cpu, non_blocking); + } + } + + return dst; +} + +at::Tensor _copy_from_and_resize( + const at::Tensor& self, + const at::Tensor& dst) { + at::native::resize_(dst, self.sizes(), std::nullopt); + return at::native::copy_(const_cast(dst), self, false); +} + +at::Scalar _local_scalar_dense(const at::Tensor& self) { + MemoryGuard guard(self); + return at::native::_local_scalar_dense_cpu(self); +} + +at::Tensor& set_source_Tensor_(at::Tensor& self, const at::Tensor& source) { + return at::native::set_tensor_(self, source); +} + +at::Tensor& set_source_Storage_(at::Tensor& self, at::Storage source) { + return at::native::set_(self, source); +} + +at::Tensor& set_source_Storage_storage_offset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::cpu::set_(result, storage, storage_offset, size, stride); +} + +at::Tensor view(const at::Tensor& self, c10::SymIntArrayRef size) { + MemoryGuard guard(self); + return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size)); +} + +// LITERALINCLUDE START: FALLBACK IMPL +void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + static const std::unordered_set cpu_fallback_blocklist = { + c10::OperatorName("aten::abs", ""), + c10::OperatorName("aten::abs", "out"), + }; + + const auto& op_name = op.schema().operator_name(); + if (cpu_fallback_blocklist.count(op_name)) { + TORCH_CHECK( + false, + "Operator '", + op_name, + "' is not implemented for device openreg."); + } else { + at::native::cpu_fallback(op, stack); + } +} +// LITERALINCLUDE END: FALLBACK IMPL + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Minimal.h b/PyTorchSimDevice/csrc/aten/native/Minimal.h new file mode 100644 index 00000000..a2e5cf02 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Minimal.h @@ -0,0 +1,61 @@ +#include "Common.h" + +namespace at::native::openreg { + +at::Tensor empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +at::Tensor as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset); + +const at::Tensor& resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format); + +at::Tensor _reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride); + +at::Tensor _copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking); + +at::Tensor _copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst); + +at::Scalar _local_scalar_dense(const at::Tensor& self); + +at::Tensor& set_source_Tensor_(at::Tensor& self, const at::Tensor& source); + +at::Tensor& set_source_Storage_(at::Tensor& self, at::Storage source); + +at::Tensor& set_source_Storage_storage_offset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride); + +at::Tensor view(const at::Tensor& self, c10::SymIntArrayRef size); + +void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.cpp new file mode 100644 index 00000000..3d35b677 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegDeviceAllocator.h" + +namespace c10::openreg { + +static OpenRegDeviceAllocator global_openreg_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.h b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.h new file mode 100644 index 00000000..c9aea4a9 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.h @@ -0,0 +1,43 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegDeviceAllocator final : at::Allocator { + OpenRegDeviceAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + int current_device_index = -1; + orGetDevice(¤t_device_index); + + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, current_device_index); + void* data = nullptr; + if (nbytes > 0) { + orMalloc(&data, nbytes); + TORCH_CHECK( + data, "Failed to allocator ", nbytes, " bytes on openreg device."); + } + return {data, data, &ReportAndDelete, curr_device}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyDeviceToDevice); + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegEvent.h b/PyTorchSimDevice/csrc/runtime/OpenRegEvent.h new file mode 100644 index 00000000..e869cf0d --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegEvent.h @@ -0,0 +1,146 @@ +#pragma once + +#include + +#include "OpenRegException.h" +#include "OpenRegStream.h" + +namespace c10::openreg { + +struct OpenRegEvent { + OpenRegEvent(bool enable_timing) noexcept : enable_timing_{enable_timing} {} + + ~OpenRegEvent() { + if (is_created_) { + OPENREG_CHECK(orEventDestroy(event_)); + } + } + + OpenRegEvent(const OpenRegEvent&) = delete; + OpenRegEvent& operator=(const OpenRegEvent&) = delete; + + OpenRegEvent(OpenRegEvent&& other) noexcept { + moveHelper(std::move(other)); + } + OpenRegEvent& operator=(OpenRegEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator orEvent_t() const { + return event(); + } + + std::optional device() const { + if (is_created_) { + return at::Device(at::kPrivateUse1, device_index_); + } else { + return std::nullopt; + } + } + + bool isCreated() const { + return is_created_; + } + + DeviceIndex device_index() const { + return device_index_; + } + + orEvent_t event() const { + return event_; + } + + bool query() const { + if (!is_created_) { + return true; + } + + orError_t err = orEventQuery(event_); + if (err == orSuccess) { + return true; + } + + return false; + } + + void record() { + record(getCurrentOpenRegStream()); + } + + void recordOnce(const OpenRegStream& stream) { + if (!was_recorded_) + record(stream); + } + + void record(const OpenRegStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK( + device_index_ == stream.device_index(), + "Event device ", + device_index_, + " does not match recording stream's device ", + stream.device_index(), + "."); + + OPENREG_CHECK(orEventRecord(event_, stream)); + was_recorded_ = true; + } + + void block(const OpenRegStream& stream) { + if (is_created_) { + OPENREG_CHECK(orStreamWaitEvent(stream, event_, 0)); + } + } + + float elapsed_time(const OpenRegEvent& other) const { + TORCH_CHECK_VALUE( + !(enable_timing_ & orEventDisableTiming) && + !(other.enable_timing_ & orEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + + float time_ms = 0; + OPENREG_CHECK(orEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + void synchronize() const { + if (is_created_) { + OPENREG_CHECK(orEventSynchronize(event_)); + } + } + + private: + unsigned int enable_timing_{orEventDisableTiming}; + bool is_created_{false}; + bool was_recorded_{false}; + DeviceIndex device_index_{-1}; + orEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + device_index_ = device_index; + OPENREG_CHECK(orEventCreateWithFlags(&event_, enable_timing_)); + is_created_ = true; + } + + void moveHelper(OpenRegEvent&& other) { + std::swap(enable_timing_, other.enable_timing_); + std::swap(is_created_, other.is_created_); + std::swap(was_recorded_, other.was_recorded_); + std::swap(device_index_, other.device_index_); + std::swap(event_, other.event_); + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegException.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegException.cpp new file mode 100644 index 00000000..09eb09b6 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegException.cpp @@ -0,0 +1,9 @@ +#include "OpenRegException.h" + +void orCheckFail( + const char* func, + const char* file, + uint32_t line, + const char* msg) { + throw ::c10::Error({func, file, line}, msg); +} diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegException.h b/PyTorchSimDevice/csrc/runtime/OpenRegException.h new file mode 100644 index 00000000..16c1ee1c --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegException.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include + +void orCheckFail( + const char* func, + const char* file, + uint32_t line, + const char* msg = ""); + +#define OPENREG_CHECK(EXPR, ...) \ + do { \ + const orError_t __err = EXPR; \ + if (__err != orSuccess) { \ + orCheckFail( \ + __func__, __FILE__, static_cast(__LINE__), ##__VA_ARGS__); \ + } \ + } while (0) diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.cpp new file mode 100644 index 00000000..566bacd0 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.cpp @@ -0,0 +1,74 @@ +#include + +#include "OpenRegException.h" +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +orError_t GetDeviceCount(int* dev_count) { + return orGetDeviceCount(dev_count); +} + +orError_t GetDevice(c10::DeviceIndex* device) { + int tmp_device = -1; + auto err = orGetDevice(&tmp_device); + *device = static_cast(tmp_device); + return err; +} + +orError_t SetDevice(c10::DeviceIndex device) { + int cur_device = -1; + orGetDevice(&cur_device); + if (device == cur_device) { + return orSuccess; + } + return orSetDevice(device); +} + +int device_count_impl() { + int count = 0; + GetDeviceCount(&count); + return count; +} + +OPENREG_EXPORT c10::DeviceIndex device_count() noexcept { + // initialize number of devices only once + static int count = []() { + try { + auto result = device_count_impl(); + TORCH_INTERNAL_ASSERT( + result <= std::numeric_limits::max(), + "Too many devices, DeviceIndex overflowed"); + return result; + } catch (const c10::Error& ex) { + // We don't want to fail, but still log the warning + // msg() returns the message without the stack trace + TORCH_WARN("Device initialization: ", ex.msg()); + return 0; + } + }(); + return static_cast(count); +} + +OPENREG_EXPORT c10::DeviceIndex current_device() { + c10::DeviceIndex cur_device = -1; + GetDevice(&cur_device); + return cur_device; +} + +OPENREG_EXPORT void set_device(c10::DeviceIndex device) { + SetDevice(device); +} + +OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) { + int current_device = -1; + orGetDevice(¤t_device); + + if (device != current_device) { + orSetDevice(device); + } + + return current_device; +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.h b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.h new file mode 100644 index 00000000..c2eb1e80 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +#include + +#include + +namespace c10::openreg { + +OPENREG_EXPORT c10::DeviceIndex device_count() noexcept; +OPENREG_EXPORT c10::DeviceIndex current_device(); +OPENREG_EXPORT void set_device(c10::DeviceIndex device); + +OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.cpp new file mode 100644 index 00000000..c2e03f66 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.cpp @@ -0,0 +1,28 @@ +#include "OpenRegGenerator.h" + +// Default, global generators, one per device. +static std::vector default_generators; + +namespace c10::openreg { + +const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) { + static bool flag [[maybe_unused]] = []() { + auto deivce_nums = device_count(); + default_generators.resize(deivce_nums); + for (auto i = 0; i < deivce_nums; i++) { + default_generators[i] = at::make_generator(i); + default_generators[i].seed(); + } + return true; + }(); + + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = current_device(); + } else { + TORCH_CHECK(idx >= 0 && idx < device_count()); + } + return default_generators[idx]; +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.h b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.h new file mode 100644 index 00000000..877a9707 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.h @@ -0,0 +1,21 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { +class OpenRegGeneratorImpl : public at::CPUGeneratorImpl { + public: + OpenRegGeneratorImpl(c10::DeviceIndex device_index) { + device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); + key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); + } + ~OpenRegGeneratorImpl() override = default; +}; + +const at::Generator& getDefaultOpenRegGenerator( + c10::DeviceIndex device_index = -1); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGuard.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.cpp new file mode 100644 index 00000000..d50e56e4 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.cpp @@ -0,0 +1,7 @@ +#include "OpenRegGuard.h" + +namespace c10::openreg { + +C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGuard.h b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.h new file mode 100644 index 00000000..f0150fe6 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.h @@ -0,0 +1,197 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +// Device guard registration +struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; + + OpenRegGuardImpl() = default; + explicit OpenRegGuardImpl(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == static_type); + } + + /** + * Return the type of device managed by this guard implementation. + */ + c10::DeviceType type() const override { + return static_type; + } + + /** + * Set the current device to Device, and return the previous c10::Device. + */ + c10::Device exchangeDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + auto old_device_index = ExchangeDevice(d.index()); + return c10::Device(static_type, old_device_index); + } + + /** + * Get the current device. + */ + c10::Device getDevice() const override { + int device_index = current_device(); + return c10::Device(static_type, device_index); + } + + /** + * Set the current device to c10::Device. + */ + void setDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Set the current device to c10::Device, without checking for errors + * (so, e.g., this can be called from a destructor). + */ + void uncheckedSetDevice(c10::Device d) const noexcept override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Get the current stream for a given device. + */ + c10::Stream getStream(c10::Device d) const noexcept override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get the default stream for a given device. + */ + c10::Stream getDefaultStream(c10::Device d) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get a stream from the global pool for a given device. + */ + c10::Stream getStreamFromGlobalPool( + c10::Device d, + bool isHighPriority = false) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Return a new stream for a given device and priority. The stream will be + * copied and shared around, device backend should be able to correctly handle + * the lifetime of the stream. + */ + c10::Stream getNewStream(c10::Device d, int priority = 0) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Set a stream to be the thread local current stream for its device. + * Return the previous stream for that device. You are NOT required + * to set the current device to match the device of this stream. + */ + c10::Stream exchangeStream(c10::Stream s) const noexcept override { + return s; + } + + /** + * Destroys the given event. + */ + void destroyEvent(void* event, const c10::DeviceIndex device_index) + const noexcept override {} + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + void record( + void** event, + const c10::Stream& stream, + const c10::DeviceIndex device_index, + const c10::EventFlag flag) const override { + static int event_id = 1; + + if (!*event) + *event = reinterpret_cast(event_id++); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + void block(void* event, const c10::Stream& stream) const override {} + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + bool queryEvent(void* event) const override { + return true; + } + + /** + * Get the number of devices. WARNING: This is REQUIRED to not raise + * an exception. If there is some sort of problem, e.g., driver error, + * you should report that there are zero available devices. + */ + c10::DeviceIndex deviceCount() const noexcept override { + int device_index = -1; + orGetDeviceCount(&device_index); + return device_index; + } + /** + * Return true if all the work previously enqueued on the stream for + * asynchronous execution has completed running on the device. + */ + bool queryStream(const c10::Stream& stream) const override { + return true; + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the stream has completed running on the device. + */ + void synchronizeStream(const c10::Stream& stream) const override {} + + /** + * Wait (by blocking the calling thread) until all the work previously + * recorded on the event has completed running on the device. + */ + void synchronizeEvent(void* event) const override {} + + /** + * Ensure the caching allocator (if any) is aware that the given DataPtr is + * being used on the given stream, and that it should thus avoid recycling the + * DataPtr until all work on that stream is done. + */ + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const c10::Stream& stream) const override {} + + /** + * Fetch the elapsed time between two recorded events. + */ + double elapsedTime( + void* event1, + void* event2, + const c10::DeviceIndex device_index) const override { + return 1; + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHooks.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.cpp new file mode 100644 index 00000000..57bc2d9f --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.cpp @@ -0,0 +1,11 @@ +#include "OpenRegHooks.h" + +namespace c10::openreg { + +static bool register_hook_flag [[maybe_unused]] = []() { + at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); + + return true; +}(); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHooks.h b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.h new file mode 100644 index 00000000..656fba8e --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.h @@ -0,0 +1,41 @@ +#include +#include + +#include +#include + +#include + +#include "OpenRegGenerator.h" + +namespace c10::openreg { +struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { + OpenRegHooksInterface() {}; + ~OpenRegHooksInterface() override = default; + + bool hasPrimaryContext(c10::DeviceIndex device_index) const override { + return true; + } + + at::Allocator* getPinnedMemoryAllocator() const override { + return at::getHostAllocator(at::kPrivateUse1); + } + + bool isPinnedPtr(const void* data) const override { + orPointerAttributes attr{}; + orPointerGetAttributes(&attr, data); + + return attr.type == orMemoryTypeHost; + } + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + return getDefaultOpenRegGenerator(device_index); + } + + at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { + return at::make_generator(device_index); + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.cpp new file mode 100644 index 00000000..55263803 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegHostAllocator.h" + +namespace c10::openreg { + +OpenRegHostAllocator caching_host_allocator; +REGISTER_HOST_ALLOCATOR(at::kPrivateUse1, &caching_host_allocator); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.h b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.h new file mode 100644 index 00000000..edef545a --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.h @@ -0,0 +1,48 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegHostAllocator final : at::HostAllocator { + OpenRegHostAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + void* data = nullptr; + if (nbytes > 0) { + orMallocHost(&data, nbytes); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); + } + return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyHostToHost); + } + + // ignore + bool record_event(void* ptr, void* ctx, c10::Stream stream) override { + return true; + } + void empty_cache() override {} + at::HostStats get_stats() override { + return at::HostStats(); + } + void reset_accumulated_stats() override {} + void reset_peak_stats() override {} +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.cpp new file mode 100644 index 00000000..43809d60 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.cpp @@ -0,0 +1,48 @@ +#include "OpenRegSerialization.h" + +namespace c10::openreg { +struct OpenRegBackendMeta : public c10::BackendMeta { + OpenRegBackendMeta(int version_number, int format_number) + : version_number_(version_number), format_number_(format_number) {} + + int version_number_{-1}; + int format_number_{-1}; +}; + +void for_serialization( + const at::Tensor& t, + std::unordered_map& m) { + auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); + + if (meta_ptr != nullptr) { + auto o_meta_ptr = dynamic_cast(meta_ptr); + if (o_meta_ptr->version_number_ == 1) { + m["version_number"] = true; + } + if (o_meta_ptr->format_number_ == 29) { + m["format_number"] = true; + } + } +} + +void for_deserialization( + const at::Tensor& t, + std::unordered_map& m) { + int version_number{-1}; + int format_number{-1}; + + if (m.find("version_number") != m.end()) { + version_number = 1; + } + if (m.find("format_number") != m.end()) { + format_number = 29; + } + + c10::intrusive_ptr meta{std::unique_ptr( + new OpenRegBackendMeta(version_number, format_number))}; + t.unsafeGetTensorImpl()->set_backend_meta(meta); +} + +REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.h b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.h new file mode 100644 index 00000000..559e92ea --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.h @@ -0,0 +1,10 @@ +#include + +#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ + FOR_SERIALIZATION, FOR_DESERIALIZATION) \ + static int register_serialization() { \ + torch::jit::TensorBackendMetaRegistry( \ + c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ + return 0; \ + } \ + static const int _temp = register_serialization(); diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegStream.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegStream.cpp new file mode 100644 index 00000000..aa6c325d --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegStream.cpp @@ -0,0 +1,253 @@ +#include "OpenRegStream.h" + +#include +#include +#include + +#include +#include +#include +#include + +namespace c10::openreg { + +namespace { + +// Global stream state and constants +static c10::once_flag init_flag; + +static DeviceIndex num_devices = -1; +static constexpr int kStreamsPerPoolBits = 5; +static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; +static constexpr int kStreamTypeBits = 2; + +/* + * The stream pools are lazily initialized when the first queue is requested + * for a device. The device flags track the initialization of each device. When + * a queue is requested, the next queue in the pool to be returned in a + * round-robin fashion, see Note [Stream Management]. + */ +static std::deque device_flags; +static std::vector, + c10::openreg::max_compile_time_stream_priorities>> + streams; +static std::deque< + std::array, max_compile_time_stream_priorities>> + priority_counters; + +static thread_local std::unique_ptr current_streams = nullptr; + +/* + * Note [StreamId assignment] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~ + * How do we assign stream IDs? + * + * -- 56 bits -- -- 5 bits -- -- 2 bits -- -- 1 bit -- + * zeros StreamIdIndex StreamIdType Ext/native stream + * ignored for ext ignored for ext + * + * Where StreamIdType: + * 00 = default stream + * 01 = normal stream + * 11 = external stream + * + * For external stream, StreamID is a orStream_t pointer. This means that last + * bit will always be 0. So when constructing StreamId for a native stream we + * set last bit to 1 to distinguish between native and external streams. + * + * StreamId is 64-bit, so we can just rely on regular promotion rules. + * We rely on StreamIdIndex and StreamIdType being non-negative; + */ +using StreamIdIndex = uint8_t; +enum class StreamIdType : uint8_t { + DEFAULT = 0x0, + NORMAL = 0x1, + EXT = 0x3, +}; + +inline std::ostream& operator<<(std::ostream& stream, StreamIdType s) { + switch (s) { + case StreamIdType::DEFAULT: + return stream << "DEFAULT"; + case StreamIdType::NORMAL: + return stream << "NORMAL"; + case StreamIdType::EXT: + return stream << "EXT"; + default: + break; + } + + return stream << static_cast(s); +} + +static inline StreamIdType streamIdType(StreamId s) { + // Externally allocated streams have their id being the orStream_ptr + // so the last bit will be 0 + if (!(s & 1)) { + return StreamIdType(StreamIdType::EXT); + } + + int mask_for_type = (1 << kStreamTypeBits) - 1; + auto st = static_cast((s >> 1) & mask_for_type); + TORCH_CHECK( + st == StreamIdType::DEFAULT || st == StreamIdType::NORMAL, + "invalid StreamId: ", + s); + return st; +} + +static inline size_t streamIdIndex(StreamId s) { + return static_cast( + (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1)); +} + +StreamId makeStreamId(StreamIdType st, size_t si) { + if (st == StreamIdType::EXT) { + return static_cast(0); + } + + return (static_cast(si) << (kStreamTypeBits + 1)) | + (static_cast(st) << 1) | 1; +} + +static void initGlobalStreamState() { + num_devices = device_count(); + device_flags.resize(num_devices); + streams.resize(num_devices); + priority_counters.resize(num_devices); +} + +static void initSingleDeviceStream( + int priority, + DeviceIndex device_index, + int i) { + auto& stream = streams[device_index][priority][i]; + + OPENREG_CHECK(orStreamCreateWithPriority(&stream, 0, priority)); + priority_counters[device_index][priority] = 0; +} + +// Creates stream pools for the specified device. It should be call only once. +static void initDeviceStreamState(DeviceIndex device_index) { + for (const auto i : c10::irange(kStreamsPerPool)) { + for (const auto p : c10::irange(max_compile_time_stream_priorities)) { + initSingleDeviceStream(p, device_index, i); + } + } +} + +static void initOpenRegStreamsOnce() { + c10::call_once(init_flag, initGlobalStreamState); + + if (current_streams) { + return; + } + + // Inits current streams (thread local) to the last queue in the "normal + // priority" queue pool. Note: the queue pool have not been initialized yet. + // It will be initialized in initDeviceStreamState for the specified device. + current_streams = std::make_unique(num_devices); + for (const auto i : c10::irange(num_devices)) { + current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0); + } +} + +static uint32_t get_idx(std::atomic& counter) { + auto raw_idx = counter++; + return raw_idx % kStreamsPerPool; +} + +OpenRegStream OpenRegStreamForId(DeviceIndex device_index, StreamId stream_id) { + return OpenRegStream( + OpenRegStream::UNCHECKED, + Stream( + Stream::UNSAFE, + c10::Device(DeviceType::PrivateUse1, device_index), + stream_id)); +} + +} // anonymous namespace + +// See Note [StreamId assignment] +orStream_t OpenRegStream::stream() const { + c10::DeviceIndex device_index = stream_.device_index(); + StreamId stream_id = stream_.id(); + StreamIdType st = streamIdType(stream_id); + size_t si = streamIdIndex(stream_id); + switch (st) { + // The index 0 stream is default as well. + case StreamIdType::DEFAULT: + case StreamIdType::NORMAL: + return streams[device_index][static_cast(st)][si]; + case StreamIdType::EXT: + return reinterpret_cast(stream_id); + default: + TORCH_CHECK( + false, + "Unrecognized stream ", + stream_, + " (I didn't recognize the stream type, ", + st, + ").", + " Did you manufacture the StreamId yourself? Don't do that;"); + } +} + +// Returns a stream from the requested pool +// Note: when called the first time on a device, this will create the +// stream pools for that device. +OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) { + initOpenRegStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + } + c10::call_once( + device_flags[device_index], initDeviceStreamState, device_index); + auto pri_idx = + std::clamp(priority, 0, max_compile_time_stream_priorities - 1); + const auto idx = get_idx(priority_counters[device_index][pri_idx]); + auto id_type = static_cast(pri_idx); + return OpenRegStreamForId(device_index, makeStreamId(id_type, idx)); +} + +OpenRegStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) { + initOpenRegStreamsOnce(); + int priority = 0; + return getStreamFromPool(priority, device); +} + +OpenRegStream getStreamFromExternal( + orStream_t ext_stream, + DeviceIndex device_index) { + return OpenRegStreamForId( + device_index, reinterpret_cast(ext_stream)); +} + +OpenRegStream getDefaultOpenRegStream(DeviceIndex device_index) { + initOpenRegStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + } + return OpenRegStreamForId( + device_index, makeStreamId(StreamIdType::DEFAULT, 0)); +} + +OpenRegStream getCurrentOpenRegStream(DeviceIndex device_index) { + initOpenRegStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + } + return OpenRegStreamForId(device_index, current_streams[device_index]); +} + +void setCurrentOpenRegStream(OpenRegStream stream) { + initOpenRegStreamsOnce(); + current_streams[stream.device_index()] = stream.id(); +} + +std::ostream& operator<<(std::ostream& stream, const OpenRegStream& s) { + return stream << s.unwrap(); +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegStream.h b/PyTorchSimDevice/csrc/runtime/OpenRegStream.h new file mode 100644 index 00000000..e1fd0c71 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegStream.h @@ -0,0 +1,162 @@ +#pragma once + +#include + +#include "OpenRegException.h" +#include "OpenRegFunctions.h" + +#include +#include +#include + +namespace c10::openreg { + +static constexpr int max_compile_time_stream_priorities = 1; + +class OpenRegStream { + public: + enum Unchecked { UNCHECKED }; + + explicit OpenRegStream(Stream stream) : stream_(stream) { + TORCH_CHECK(stream_.device_type() == DeviceType::PrivateUse1); + } + + explicit OpenRegStream(Unchecked, Stream stream) : stream_(stream) {} + + bool operator==(const OpenRegStream& other) const noexcept { + return unwrap() == other.unwrap(); + } + + bool operator!=(const OpenRegStream& other) const noexcept { + return unwrap() != other.unwrap(); + } + + operator orStream_t() const { + return stream(); + } + + operator Stream() const { + return unwrap(); + } + + DeviceType device_type() const { + return DeviceType::PrivateUse1; + } + + DeviceIndex device_index() const { + return stream_.device_index(); + } + + Device device() const { + return Device(DeviceType::PrivateUse1, device_index()); + } + + StreamId id() const { + return stream_.id(); + } + + bool query() const { + DeviceGuard guard{stream_.device()}; + + if (orStreamQuery(stream()) == orSuccess) { + return true; + } + + return false; + } + + void synchronize() const { + DeviceGuard guard{stream_.device()}; + OPENREG_CHECK(orStreamSynchronize(stream())); + } + + int priority() const { + DeviceGuard guard{stream_.device()}; + int priority = 0; + OPENREG_CHECK(orStreamGetPriority(stream(), &priority)); + return priority; + } + + orStream_t stream() const; + + Stream unwrap() const { + return stream_; + } + + struct c10::StreamData3 pack3() const { + return stream_.pack3(); + } + + static OpenRegStream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + return OpenRegStream(Stream::unpack3(stream_id, device_index, device_type)); + } + + private: + Stream stream_; +}; + +/* + * Get a stream from the pool in a round-robin fashion. + * + * You can request a stream from the highest priority pool by setting + * isHighPriority to true for a specific device. + */ +OPENREG_EXPORT OpenRegStream +getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); + +/* + * Get a stream from the pool in a round-robin fashion. + * + * You can request a stream by setting a priority value for a specific device. + * The priority number lower, the priority higher. + */ +OPENREG_EXPORT OpenRegStream +getStreamFromPool(const int priority, DeviceIndex device = -1); + +/* + * Get a OpenRegStream from a externally allocated one. + * + * This is mainly for interoperability with different libraries where we + * want to operate on a non-torch allocated stream for data exchange or similar + * purposes + */ +OPENREG_EXPORT OpenRegStream +getStreamFromExternal(orStream_t ext_stream, DeviceIndex device_index); + +/* + * Get the default OpenReg stream, for the passed OpenReg device, or for the + * current device if no device index is passed. + */ +OPENREG_EXPORT OpenRegStream +getDefaultOpenRegStream(DeviceIndex device_index = -1); + +/* + * Get the current OpenReg stream, for the passed OpenReg device, or for the + * current device if no device index is passed. + */ +OPENREG_EXPORT OpenRegStream +getCurrentOpenRegStream(DeviceIndex device_index = -1); + +/* + * Set the current stream on the device of the passed in stream to be the passed + * in stream. + */ +OPENREG_EXPORT void setCurrentOpenRegStream(OpenRegStream stream); + +OPENREG_EXPORT std::ostream& operator<<( + std::ostream& stream, + const OpenRegStream& s); + +} // namespace c10::openreg + +namespace std { +template <> +struct hash { + size_t operator()(c10::openreg::OpenRegStream s) const noexcept { + return std::hash{}(s.unwrap()); + } +}; +} // namespace std diff --git a/PyTorchSimDevice/include/Macros.h b/PyTorchSimDevice/include/Macros.h new file mode 100644 index 00000000..c75523c2 --- /dev/null +++ b/PyTorchSimDevice/include/Macros.h @@ -0,0 +1,7 @@ +#pragma once + +#ifdef _WIN32 +#define OPENREG_EXPORT __declspec(dllexport) +#else +#define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif diff --git a/PyTorchSimDevice/pyproject.toml b/PyTorchSimDevice/pyproject.toml new file mode 100644 index 00000000..774fe5cd --- /dev/null +++ b/PyTorchSimDevice/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = [ + "setuptools", + "wheel", + "torch", # Needed by setup.py for getting include of PyTorch +] + +build-backend = "setuptools.build_meta" + +[project] +name = "torch_openreg" +version = "0.0.1" +description = "A minimal reference implementation of an out-of-tree backend" +readme = "README.md" +requires-python = ">=3.9" +license = { text = "BSD-3-Clause" } +authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] +dependencies = [ + "torch", +] +# Add classifiers info for making lint happy +classifiers = [ + "Development Status :: 4 - Beta", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Programming Language :: C++", + "Programming Language :: Python :: 3 :: Only", +] + +[project.urls] +Homepage = "https://pytorch.org" +Repository = "https://github.com/pytorch/pytorch" +Documentation = "https://pytorch.org/docs" +Forum = "https://discuss.pytorch.org" diff --git a/PyTorchSimDevice/setup.py b/PyTorchSimDevice/setup.py new file mode 100644 index 00000000..01e2f065 --- /dev/null +++ b/PyTorchSimDevice/setup.py @@ -0,0 +1,148 @@ +import multiprocessing +import os +import platform +import shutil +import subprocess +import sys +import sysconfig +from distutils.command.clean import clean + +from setuptools import Extension, find_packages, setup + + +# Env Variables +IS_DARWIN = platform.system() == "Darwin" +IS_WINDOWS = platform.system() == "Windows" + +BASE_DIR = os.path.dirname(os.path.realpath(__file__)) +RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv) + + +def make_relative_rpath_args(path): + if IS_DARWIN: + return ["-Wl,-rpath,@loader_path/" + path] + elif IS_WINDOWS: + return [] + else: + return ["-Wl,-rpath,$ORIGIN/" + path] + + +def get_pytorch_dir(): + os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + import torch + + return os.path.dirname(os.path.realpath(torch.__file__)) + + +def build_deps(): + build_dir = os.path.join(BASE_DIR, "build") + os.makedirs(build_dir, exist_ok=True) + + cmake_args = [ + "-DCMAKE_INSTALL_PREFIX=" + + os.path.realpath(os.path.join(BASE_DIR, "torch_openreg")), + "-DPYTHON_INCLUDE_DIR=" + sysconfig.get_paths().get("include"), + "-DPYTORCH_INSTALL_DIR=" + get_pytorch_dir(), + ] + + subprocess.check_call( + ["cmake", BASE_DIR] + cmake_args, cwd=build_dir, env=os.environ + ) + + build_args = [ + "--build", + ".", + "--target", + "install", + "--config", # For multi-config generators + "Release", + "--", + ] + + if IS_WINDOWS: + build_args += ["/m:" + str(multiprocessing.cpu_count())] + else: + build_args += ["-j", str(multiprocessing.cpu_count())] + + command = ["cmake"] + build_args + subprocess.check_call(command, cwd=build_dir, env=os.environ) + + +class BuildClean(clean): + def run(self): + for i in ["build", "install", "torch_openreg/lib"]: + dirs = os.path.join(BASE_DIR, i) + if os.path.exists(dirs) and os.path.isdir(dirs): + shutil.rmtree(dirs) + + for dirpath, _, filenames in os.walk(os.path.join(BASE_DIR, "torch_openreg")): + for filename in filenames: + if filename.endswith(".so"): + os.remove(os.path.join(dirpath, filename)) + + +def main(): + if not RUN_BUILD_DEPS: + build_deps() + + if IS_WINDOWS: + # /NODEFAULTLIB makes sure we only link to DLL runtime + # and matches the flags set for protobuf and ONNX + extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] + [ + *make_relative_rpath_args("lib") + ] + # /MD links against DLL runtime + # and matches the flags set for protobuf and ONNX + # /EHsc is about standard C++ exception handling + extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"] + else: + extra_link_args = [*make_relative_rpath_args("lib")] + extra_compile_args = [ + "-Wall", + "-Wextra", + "-Wno-strict-overflow", + "-Wno-unused-parameter", + "-Wno-missing-field-initializers", + "-Wno-unknown-pragmas", + "-fno-strict-aliasing", + ] + + ext_modules = [ + Extension( + name="torch_openreg._C", + sources=["torch_openreg/csrc/stub.c"], + language="c", + extra_compile_args=extra_compile_args, + libraries=["torch_bindings"], + library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")], + extra_link_args=extra_link_args, + ) + ] + + package_data = { + "torch_openreg": [ + "lib/*.so*", + "lib/*.dylib*", + "lib/*.dll", + "lib/*.lib", + ] + } + + setup( + packages=find_packages(), + package_data=package_data, + ext_modules=ext_modules, + cmdclass={ + "clean": BuildClean, # type: ignore[misc] + }, + include_package_data=False, + entry_points={ + "torch.backends": [ + "torch_openreg = torch_openreg:_autoload", + ], + }, + ) + + +if __name__ == "__main__": + main() diff --git a/PyTorchSimDevice/third_party/openreg/CMakeLists.txt b/PyTorchSimDevice/third_party/openreg/CMakeLists.txt new file mode 100644 index 00000000..1bde7e00 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(TORCH_OPENREG CXX C) + + +set(LIBRARY_NAME openreg) +set(LIBRARY_TEST ortests) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/PyTorchSimDevice/third_party/openreg/README.md b/PyTorchSimDevice/third_party/openreg/README.md new file mode 100644 index 00000000..0cee2c87 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/README.md @@ -0,0 +1,151 @@ +# OpenReg: An Accelerator Backend that Simulates CUDA Behavior on a CPU + +## Introduction + +OpenReg is a C++ backend library that simulates the behavior of a CUDA-like device on a CPU. Its core objective is **not to accelerate computation or improve performance**, but rather to **simulate modern CUDA programming, enabling developers to prototype and test in an environment without actual GPU hardware**. The current design principles are as follows: + +* **API Consistency**: Provide an interface consistent with the CUDA Runtime API, allowing upper-level applications (like PyTorch's `PrivateUse1` backend) to switch and test seamlessly. +* **Functional Consistency**: Provide behavior consistent with the CUDA Runtime, such as memory isolation, device context management, etc. +* **Completeness**: Aim to support `PrivateUse1` device integration and safeguard the third-party device integration mechanism, without striving to cover all capabilities of the CUDA Runtime. + +## Directory Structure + +The project's code is organized with a clear structure and separation of responsibilities: + +```text +openreg/ +├── README.md # Comprehensive introduction of OpenReg. +├── CMakeLists.txt # Top-level CMake build script, used to compile and generate libopenreg.so +├── cmake/ +│ └── GTestTargets.cmake # Utils of fetching GoogleTest. +├── include/ +│ ├── openreg.h # Public API header file, external users only need to include this file +│ └── openreg.inl # Public API header file, as an extension of openreg.h, cannot be included separately. +├── example/ +│ └── example.cpp # Example for OpenReg. +├── tests/ +│ ├── event_tests.cpp # Testcases about OpenReg Event. +│ ├── stream_tests.cpp # Testcases about OpenReg Stream. +│ ├── device_tests.cpp # Testcases about OpenReg Device. +│ └── memory_tests.cpp # Testcases about OpenReg Memory. +└── csrc/ + ├── device.cpp # Implementation of device management APIs + ├── memory.cpp # Implementation of memory management APIs + └── stream.cpp # Implementation of stream and event APIs. +``` + +* `CMakeLists.txt`: Responsible for compiling and linking all source files under the `csrc/` directory to generate the final `libopenreg.so` shared library. +* `include`: Defines all externally exposed APIs, data structures, and enums. + * `openreg.h`: Defines all externally exposed C-style APIs. + * `openreg.inl`: Defines all externally exposed C++ APIs. +* `csrc/`: Contains the C++ implementation source code for all core functionalities. + * `device.cpp`: Implements the core functions of device management: device discovery and context management. + * `memory.cpp`: Implements the core functions of memory management: allocation, free, copy and memory protection. + * `stream.cpp`: Implements the core functions of stream and event: creation, destroy, record, synchronization and so on. + +## Implemented APIs + +OpenReg currently provides a set of APIs covering basic memory and device management. + +### Device Management APIs + +| OpenReg | CUDA | Feature Description | +| :------------------------------- | :--------------------------------- | :--------------------------------- | +| `orGetDeviceCount` | `cudaGetDeviceCount` | Get the number of available GPUs | +| `orSetDevice` | `cudaSetDevice` | Set the active GPU | +| `orGetDevice` | `cudaGetDevice` | Get the current GPU | +| `orDeviceSynchronize` | `cudaDeviceSynchronize` | Wait for all GPU tasks to finish | +| `orDeviceGetStreamPriorityRange` | `cudaDeviceGetStreamPriorityRange` | Get the range of stream priorities | + +### Memory Management APIs + +| OpenReg | CUDA | Feature Description | +| :----------------------- | :------------------------- | :---------------------------------------- | +| `orMalloc` | `cudaMalloc` | Allocate device memory | +| `orFree` | `cudaFree` | Free device memory | +| `orMallocHost` | `cudaMallocHost` | Allocate page-locked (Pinned) host memory | +| `orFreeHost` | `cudaFreeHost` | Free page-locked host memory | +| `orMemcpy` | `cudaMemcpy` | Synchronous memory copy | +| `orMemcpyAsyn` | `cudaMemcpyAsyn` | Asynchronous memory copy | +| `orPointerGetAttributes` | `cudaPointerGetAttributes` | Get pointer attributes | + +### Stream APIs + +| OpenReg | CUDA | Feature Description | +| :--------------------------- | :----------------------------- | :------------------------------------- | +| `orStreamCreate` | `cudaStreamCreate` | Create a default-priority stream | +| `orStreamCreateWithPriority` | `cudaStreamCreateWithPriority` | Create a stream with a given priority | +| `orStreamDestroy` | `cudaStreamDestroy` | Destroy a stream | +| `orStreamQuery` | `cudaStreamQuery` | Check if a stream has completed | +| `orStreamSynchronize` | `cudaStreamSynchronize` | Wait for a stream to complete | +| `orStreamWaitEvent` | `cudaStreamWaitEvent` | Make a stream wait for an event | +| `orStreamGetPriority` | `cudaStreamGetPriority` | Get a stream’s priority | + +### Event APIs + +| OpenReg | CUDA | Feature Description | +| :----------------------- | :------------------------- | :---------------------------------- | +| `orEventCreate` | `cudaEventCreate` | Create an event with default flag | +| `orEventCreateWithFlags` | `cudaEventCreateWithFlags` | Create an event with specific flag | +| `orEventDestroy` | `cudaEventDestroy` | Destroy an event | +| `orEventRecord` | `cudaEventRecord` | Record an event in a stream | +| `orEventSynchronize` | `cudaEventSynchronize` | Wait for an event to complete | +| `orEventQuery` | `cudaEventQuery` | Check if an event has completed | +| `orEventElapsedTime` | `cudaEventElapsedTime` | Get time elapsed between two events | + +## Implementation Principles + +### Device Management Principles + +Simulating multiple devices and thread-safe device context switching: + +1. **Device Count**: The total number of simulated devices is defined by the compile-time constant `constexpr int kDeviceCount`. +2. **Device Switching**: Device switching in multi-threaded scenarios is simulated using a **TLS (Thread-Local Storage) global variable**. + +### Memory Management Principles + +Simulating device memory, host memory, and memory copies: + +1. **Allocation**: A page-aligned memory block is allocated using `mmap` + `mprotect` with the permission flag `PROT_NONE`. Read, write, and execute operations on this memory region are all prohibited. +2. **Deallocation**: Memory is freed using `munmap`. +3. **Authorization**: When a legitimate memory access is required, an RAII guard restores the memory permissions to `PROT_READ | PROT_WRITE`. The permissions are automatically reverted to `PROT_NONE` when the scope is exited. + +### Stream&Event Principles + +Simulating creation, release and synchronization for event and steam: + +1. **Event**: Each event is encapsulated as a task function and placed into a stream, which acts as a thread. Upon completion of the task, a flag within the event is modified to simulate the event's status. +2. **Stream**: When each stream is requested, a new thread is created, which sequentially processes each task in the task queue within the stream structure. Tasks can be wrappers around kernel functions or events. +3. **Synchronization**: Synchronization between streams and events is achieved using multithreading, condition variables, and mutexes. + +## Usage Example + +Please refer to [example](example/example.cpp) for example. + +The command to compile example.cpp is as follow: + +```Shell +mkdir build + +pushd build +cmake .. +make -j 32 +popd + +g++ -o out example/example.cpp -L ./build -lopenreg +LD_LIBRARY_PATH=./build ./out +``` + +The output is as follow: + +```Shell +Current environment have 2 devices +Current is 0 device +All tasks have been submitted. +Kernel execution time: 0.238168 ms +Verification PASSED! +``` + +## Next Steps + +The most basic functions of the OpenReg backend are currently supported, and will be dynamically optimized and expanded based on the needs of PyTorch integration. diff --git a/PyTorchSimDevice/third_party/openreg/cmake/GTestTargets.cmake b/PyTorchSimDevice/third_party/openreg/cmake/GTestTargets.cmake new file mode 100644 index 00000000..777fc489 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/cmake/GTestTargets.cmake @@ -0,0 +1,12 @@ +set(GTest_REL_PATH "../../../../../../../third_party/googletest") +get_filename_component(GTest_DIR "${CMAKE_CURRENT_LIST_DIR}/${GTest_REL_PATH}" ABSOLUTE) + +if(EXISTS "${GTest_DIR}/CMakeLists.txt") + message(STATUS "Found GTest: ${GTest_DIR}") + + set(BUILD_GMOCK OFF CACHE BOOL "Disable GMock build") + set(INSTALL_GTEST OFF CACHE BOOL "Disable GTest install") + add_subdirectory(${GTest_DIR} "${CMAKE_BINARY_DIR}/gtest") +else() + message(FATAL_ERROR "GTest Not Found") +endif() diff --git a/PyTorchSimDevice/third_party/openreg/csrc/device.cpp b/PyTorchSimDevice/third_party/openreg/csrc/device.cpp new file mode 100644 index 00000000..9643bc59 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/device.cpp @@ -0,0 +1,37 @@ +#include + +namespace { + +// Total device numbers +constexpr int DEVICE_COUNT = 2; +// Current device index +thread_local int gCurrentDevice = 0; + +} // namespace + +orError_t orGetDeviceCount(int* count) { + if (!count) { + return orErrorUnknown; + } + + *count = DEVICE_COUNT; + return orSuccess; +} + +orError_t orGetDevice(int* device) { + if (!device) { + return orErrorUnknown; + } + + *device = gCurrentDevice; + return orSuccess; +} + +orError_t orSetDevice(int device) { + if (device < 0 || device >= DEVICE_COUNT) { + return orErrorUnknown; + } + + gCurrentDevice = device; + return orSuccess; +} diff --git a/PyTorchSimDevice/third_party/openreg/csrc/memory.cpp b/PyTorchSimDevice/third_party/openreg/csrc/memory.cpp new file mode 100644 index 00000000..6f02eeb0 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/memory.cpp @@ -0,0 +1,259 @@ +#include "memory.h" + +#include + +#include +#include + +namespace { + +struct Block { + orMemoryType type = orMemoryType::orMemoryTypeUnmanaged; + int device = -1; + void* pointer = nullptr; + size_t size = 0; + int refcount{0}; +}; + +class MemoryManager { + public: + static MemoryManager& getInstance() { + static MemoryManager instance; + return instance; + } + + orError_t allocate(void** ptr, size_t size, orMemoryType type) { + if (!ptr || size == 0) + return orErrorUnknown; + + std::lock_guard lock(m_mutex); + long page_size = openreg::get_pagesize(); + size_t aligned_size = ((size - 1) / page_size + 1) * page_size; + void* mem = nullptr; + int current_device = -1; + + if (type == orMemoryType::orMemoryTypeDevice) { + orGetDevice(¤t_device); + + mem = openreg::mmap(aligned_size); + if (mem == nullptr) + return orErrorUnknown; + if (openreg::mprotect(mem, aligned_size, F_PROT_NONE) != 0) { + openreg::munmap(mem, aligned_size); + return orErrorUnknown; + } + } else { + if (openreg::alloc(&mem, page_size, aligned_size) != 0) { + return orErrorUnknown; + } + } + + m_registry[mem] = {type, current_device, mem, aligned_size, 0}; + *ptr = mem; + return orSuccess; + } + + orError_t free(void* ptr) { + if (!ptr) + return orSuccess; + + std::lock_guard lock(m_mutex); + auto it = m_registry.find(ptr); + if (it == m_registry.end()) + return orErrorUnknown; + + const auto& info = it->second; + if (info.type == orMemoryType::orMemoryTypeDevice) { + openreg::mprotect(info.pointer, info.size, F_PROT_READ | F_PROT_WRITE); + openreg::munmap(info.pointer, info.size); + } else { + openreg::free(info.pointer); + } + + m_registry.erase(it); + return orSuccess; + } + + orError_t memcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + if (!dst || !src || count == 0) + return orErrorUnknown; + + std::lock_guard lock(m_mutex); + Block* dst_info = getBlockInfoNoLock(dst); + Block* src_info = getBlockInfoNoLock(src); + + switch (kind) { + case orMemcpyHostToDevice: + if ((!dst_info || dst_info->type != orMemoryType::orMemoryTypeDevice) || + (src_info && src_info->type == orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + case orMemcpyDeviceToHost: + if ((dst_info && dst_info->type == orMemoryType::orMemoryTypeDevice) || + (!src_info || src_info->type != orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + case orMemcpyDeviceToDevice: + if ((!dst_info || dst_info->type != orMemoryType::orMemoryTypeDevice) || + (!src_info || src_info->type != orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + case orMemcpyHostToHost: + if ((dst_info && dst_info->type == orMemoryType::orMemoryTypeDevice) || + (src_info && src_info->type == orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + } + + unprotectNoLock(dst_info); + unprotectNoLock(src_info); + ::memcpy(dst, src, count); + protectNoLock(dst_info); + protectNoLock(src_info); + + return orSuccess; + } + + orError_t getPointerAttributes( + orPointerAttributes* attributes, + const void* ptr) { + if (!attributes || !ptr) + return orErrorUnknown; + + std ::lock_guard lock(m_mutex); + Block* info = getBlockInfoNoLock(ptr); + + if (!info) { + attributes->type = orMemoryType::orMemoryTypeUnmanaged; + attributes->device = -1; + attributes->pointer = const_cast(ptr); + } else { + attributes->type = info->type; + attributes->device = info->device; + attributes->pointer = info->pointer; + } + + return orSuccess; + } + + orError_t unprotect(void* ptr) { + std::lock_guard lock(m_mutex); + return unprotectNoLock(getBlockInfoNoLock(ptr)); + } + + orError_t protect(void* ptr) { + std::lock_guard lock(m_mutex); + return protectNoLock(getBlockInfoNoLock(ptr)); + } + + private: + MemoryManager() = default; + + orError_t unprotectNoLock(Block* info) { + if (info && info->type == orMemoryType::orMemoryTypeDevice) { + if (info->refcount == 0) { + if (openreg::mprotect( + info->pointer, info->size, F_PROT_READ | F_PROT_WRITE) != 0) { + return orErrorUnknown; + } + } + + info->refcount++; + } + + return orSuccess; + } + + orError_t protectNoLock(Block* info) { + if (info && info->type == orMemoryType::orMemoryTypeDevice) { + if (info->refcount == 1) { + if (openreg::mprotect(info->pointer, info->size, F_PROT_NONE) != 0) { + return orErrorUnknown; + } + } + + info->refcount--; + } + + return orSuccess; + } + + Block* getBlockInfoNoLock(const void* ptr) { + auto it = m_registry.upper_bound(const_cast(ptr)); + if (it != m_registry.begin()) { + --it; + const char* p_char = static_cast(ptr); + const char* base_char = static_cast(it->first); + if (p_char >= base_char && p_char < (base_char + it->second.size)) { + return &it->second; + } + } + + return nullptr; + } + + std::map m_registry; + std::mutex m_mutex; +}; + +} // namespace + +orError_t orMalloc(void** devPtr, size_t size) { + return MemoryManager::getInstance().allocate( + devPtr, size, orMemoryType::orMemoryTypeDevice); +} + +orError_t orFree(void* devPtr) { + return MemoryManager::getInstance().free(devPtr); +} + +orError_t orMallocHost(void** hostPtr, size_t size) { + return MemoryManager::getInstance().allocate( + hostPtr, size, orMemoryType::orMemoryTypeHost); +} + +orError_t orFreeHost(void* hostPtr) { + return MemoryManager::getInstance().free(hostPtr); +} + +orError_t orMemcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + return MemoryManager::getInstance().memcpy(dst, src, count, kind); +} + +orError_t orMemcpyAsync( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind, + orStream_t stream) { + if (!stream) { + return orErrorUnknown; + } + + auto& mm = MemoryManager::getInstance(); + + return orLaunchKernel( + stream, &MemoryManager::memcpy, &mm, dst, src, count, kind); +} + +orError_t orPointerGetAttributes( + orPointerAttributes* attributes, + const void* ptr) { + return MemoryManager::getInstance().getPointerAttributes(attributes, ptr); +} + +orError_t orMemoryUnprotect(void* devPtr) { + return MemoryManager::getInstance().unprotect(devPtr); +} + +orError_t orMemoryProtect(void* devPtr) { + return MemoryManager::getInstance().protect(devPtr); +} diff --git a/PyTorchSimDevice/third_party/openreg/csrc/memory.h b/PyTorchSimDevice/third_party/openreg/csrc/memory.h new file mode 100644 index 00000000..35851ac9 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/memory.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include + +#if defined(_WIN32) +#include +#else +#include +#include +#endif + +#define F_PROT_NONE 0x0 +#define F_PROT_READ 0x1 +#define F_PROT_WRITE 0x2 + +namespace openreg { + +void* mmap(size_t size) { +#if defined(_WIN32) + return VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); +#else + void* addr = ::mmap( + nullptr, + size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0); + return (addr == MAP_FAILED) ? nullptr : addr; +#endif +} + +void munmap(void* addr, size_t size) { +#if defined(_WIN32) + VirtualFree(addr, 0, MEM_RELEASE); +#else + ::munmap(addr, size); +#endif +} + +int mprotect(void* addr, size_t size, int prot) { +#if defined(_WIN32) + DWORD win_prot = 0; + DWORD old; + if (prot == F_PROT_NONE) { + win_prot = PAGE_NOACCESS; + } else { + win_prot = PAGE_READWRITE; + } + + return VirtualProtect(addr, size, win_prot, &old) ? 0 : -1; +#else + int native_prot = 0; + if (prot == F_PROT_NONE) + native_prot = PROT_NONE; + else { + if (prot & F_PROT_READ) + native_prot |= PROT_READ; + if (prot & F_PROT_WRITE) + native_prot |= PROT_WRITE; + } + + return ::mprotect(addr, size, native_prot); +#endif +} + +int alloc(void** mem, size_t alignment, size_t size) { +#ifdef _WIN32 + *mem = _aligned_malloc(size, alignment); + return *mem ? 0 : -1; +#else + return posix_memalign(mem, alignment, size); +#endif +} + +void free(void* mem) { +#ifdef _WIN32 + _aligned_free(mem); +#else + ::free(mem); +#endif +} + +long get_pagesize() { +#ifdef _WIN32 + SYSTEM_INFO si; + GetSystemInfo(&si); + return static_cast(si.dwPageSize); +#else + return sysconf(_SC_PAGESIZE); +#endif +} + +} // namespace openreg diff --git a/PyTorchSimDevice/third_party/openreg/csrc/stream.cpp b/PyTorchSimDevice/third_party/openreg/csrc/stream.cpp new file mode 100644 index 00000000..30f50b1a --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/stream.cpp @@ -0,0 +1,313 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +static std::mutex g_mutex; +static std::once_flag g_flag; +static std::vector> g_streams_per_device; + +static void initialize_registries() { + int device_count = 0; + orGetDeviceCount(&device_count); + g_streams_per_device.resize(device_count); +} + +struct orEventImpl { + std::mutex mtx; + std::condition_variable cv; + std::atomic completed{true}; + int device_index = -1; + bool timing_enabled{false}; + std::chrono::high_resolution_clock::time_point completion_time; +}; + +struct orEvent { + std::shared_ptr impl; +}; + +struct orStream { + std::queue> tasks; + std::mutex mtx; + std::condition_variable cv; + std::thread worker; + std::atomic stop_flag{false}; + int device_index = -1; + + orStream() { + worker = std::thread([this] { + while (true) { + std::function task; + { + std::unique_lock lock(this->mtx); + this->cv.wait(lock, [this] { + return this->stop_flag.load() || !this->tasks.empty(); + }); + if (this->stop_flag.load() && this->tasks.empty()) { + return; + } + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + task(); + } + }); + } + + ~orStream() { + stop_flag.store(true); + cv.notify_one(); + worker.join(); + } +}; + +orError_t openreg::addTaskToStream( + orStream_t stream, + std::function task) { + if (!stream) + return orErrorUnknown; + + { + std::lock_guard lock(stream->mtx); + stream->tasks.push(std::move(task)); + } + + stream->cv.notify_one(); + return orSuccess; +} + +orError_t orEventCreateWithFlags(orEvent_t* event, unsigned int flags) { + if (!event) + return orErrorUnknown; + + auto impl = std::make_shared(); + orGetDevice(&(impl->device_index)); + if (flags & orEventEnableTiming) { + impl->timing_enabled = true; + } + + *event = new orEvent{std::move(impl)}; + return orSuccess; +} + +orError_t orEventCreate(orEvent_t* event) { + return orEventCreateWithFlags(event, orEventDisableTiming); +} + +orError_t orEventDestroy(orEvent_t event) { + if (!event) + return orErrorUnknown; + + delete event; + return orSuccess; +} + +orError_t orEventRecord(orEvent_t event, orStream_t stream) { + if (!event || !stream) + return orErrorUnknown; + + auto event_impl = event->impl; + event_impl->completed.store(false); + auto record_task = [event_impl]() { + if (event_impl->timing_enabled) { + event_impl->completion_time = std::chrono::high_resolution_clock::now(); + } + + { + std::lock_guard lock(event_impl->mtx); + event_impl->completed.store(true); + } + + event_impl->cv.notify_all(); + }; + + return openreg::addTaskToStream(stream, record_task); +} + +orError_t orEventSynchronize(orEvent_t event) { + if (!event) + return orErrorUnknown; + + auto event_impl = event->impl; + std::unique_lock lock(event_impl->mtx); + event_impl->cv.wait(lock, [&] { return event_impl->completed.load(); }); + + return orSuccess; +} + +orError_t orEventQuery(orEvent_t event) { + if (!event) + return orErrorUnknown; + + return event->impl->completed.load() ? orSuccess : orErrorNotReady; +} + +orError_t orEventElapsedTime(float* ms, orEvent_t start, orEvent_t end) { + if (!ms || !start || !end) + return orErrorUnknown; + + auto start_impl = start->impl; + auto end_impl = end->impl; + + if (start_impl->device_index != end_impl->device_index) { + return orErrorUnknown; + } + + if (!start_impl->timing_enabled || !end_impl->timing_enabled) { + return orErrorUnknown; + } + + if (!start_impl->completed.load() || !end_impl->completed.load()) { + return orErrorUnknown; + } + + auto duration = end_impl->completion_time - start_impl->completion_time; + *ms = std::chrono::duration_cast>( + duration) + .count(); + + return orSuccess; +} + +orError_t orStreamCreateWithPriority( + orStream_t* stream, + [[maybe_unused]] unsigned int flag, + int priority) { + if (!stream) { + return orErrorUnknown; + } + + int min_p, max_p; + orDeviceGetStreamPriorityRange(&min_p, &max_p); + if (priority < min_p || priority > max_p) { + return orErrorUnknown; + } + + int current_device = 0; + orGetDevice(¤t_device); + + orStream_t new_stream = nullptr; + new_stream = new orStream(); + new_stream->device_index = current_device; + + { + std::lock_guard lock(g_mutex); + std::call_once(g_flag, initialize_registries); + g_streams_per_device[current_device].insert(new_stream); + } + + *stream = new_stream; + + return orSuccess; +} + +orError_t orStreamCreate(orStream_t* stream) { + int min_p, max_p; + orDeviceGetStreamPriorityRange(&min_p, &max_p); + + return orStreamCreateWithPriority(stream, 0, max_p); +} + +orError_t orStreamGetPriority( + [[maybe_unused]] orStream_t stream, + int* priority) { + // Since OpenReg has only one priority level, the following code + // returns 0 directly for convenience. + *priority = 0; + + return orSuccess; +} + +orError_t orStreamDestroy(orStream_t stream) { + if (!stream) + return orErrorUnknown; + + { + std::lock_guard lock(g_mutex); + + int device_idx = stream->device_index; + if (device_idx >= 0 && device_idx < g_streams_per_device.size()) { + g_streams_per_device[device_idx].erase(stream); + } + } + + delete stream; + return orSuccess; +} + +orError_t orStreamQuery(orStream_t stream) { + if (!stream) { + return orErrorUnknown; + } + + std::lock_guard lock(stream->mtx); + return stream->tasks.empty() ? orSuccess : orErrorNotReady; +} + +orError_t orStreamSynchronize(orStream_t stream) { + if (!stream) + return orErrorUnknown; + + orEvent_t event; + orEventCreate(&event); + orEventRecord(event, stream); + + orError_t status = orEventSynchronize(event); + orEventDestroy(event); + + return status; +} + +orError_t orStreamWaitEvent(orStream_t stream, orEvent_t event, unsigned int) { + if (!stream || !event) + return orErrorUnknown; + + auto event_impl = event->impl; + auto wait_task = [event_impl]() { + std::unique_lock lock(event_impl->mtx); + event_impl->cv.wait(lock, [&] { return event_impl->completed.load(); }); + }; + + return openreg::addTaskToStream(stream, wait_task); +} + +orError_t orDeviceGetStreamPriorityRange( + int* leastPriority, + int* greatestPriority) { + if (!leastPriority || !greatestPriority) { + return orErrorUnknown; + } + + // OpenReg have only one priority now. + *leastPriority = 0; + *greatestPriority = 0; + return orSuccess; +} + +orError_t orDeviceSynchronize(void) { + int current_device = 0; + orGetDevice(¤t_device); + + std::vector streams; + { + std::lock_guard lock(g_mutex); + std::call_once(g_flag, initialize_registries); + + auto& streams_on_device = g_streams_per_device[current_device]; + streams.assign(streams_on_device.begin(), streams_on_device.end()); + } + + for (orStream_t stream : streams) { + orError_t status = orStreamSynchronize(stream); + if (status != orSuccess) { + return status; + } + } + + return orSuccess; +} diff --git a/PyTorchSimDevice/third_party/openreg/example/example.cpp b/PyTorchSimDevice/third_party/openreg/example/example.cpp new file mode 100644 index 00000000..f00f1909 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/example/example.cpp @@ -0,0 +1,112 @@ +#include "include/openreg.h" + +#include +#include +#include +#include + +struct MemoryGuard { + MemoryGuard(void* ptr) : ptr_(ptr) { + orMemoryUnprotect(ptr_); + } + ~MemoryGuard() { + orMemoryProtect(ptr_); + } + + private: + void* ptr_{}; +}; + +void add_kernel(float* out, float* a, float* b, int num) { + for (int i = 0; i < num; ++i) { + out[i] = a[i] + b[i]; + } +} + +int main() { + int device_count = 0; + orGetDeviceCount(&device_count); + + std::cout << "Current environment have " << device_count << " devices" + << std::endl; + + orSetDevice(0); + int current_device = -1; + orGetDevice(¤t_device); + + std::cout << "Current is " << current_device << " device" << std::endl; + + constexpr int num = 50000; + constexpr size_t size = num * sizeof(float); + + std::vector host_a(num), host_b(num), host_out(num, 0.0f); + std::iota(host_a.begin(), host_a.end(), 0.0f); + for (int i = 0; i < num; ++i) { + host_b[i] = 2.0f; + } + + float *dev_a, *dev_b, *dev_out; + orMalloc((void**)&dev_a, size); + orMalloc((void**)&dev_b, size); + orMalloc((void**)&dev_out, size); + + // There will be subsequent memory access operations, so memory protection + // needs to be released + MemoryGuard a{dev_a}; + MemoryGuard b{dev_b}; + MemoryGuard c{dev_out}; + + orStream_t stream1, stream2; + orEvent_t start_event, stop_event; + + orStreamCreate(&stream1); + orStreamCreate(&stream2); + orEventCreateWithFlags(&start_event, orEventEnableTiming); + orEventCreateWithFlags(&stop_event, orEventEnableTiming); + + // Copy input from host to device + orMemcpyAsync(dev_a, host_a.data(), size, orMemcpyHostToDevice, stream1); + orMemcpyAsync(dev_b, host_b.data(), size, orMemcpyHostToDevice, stream1); + + // Submit compute kernel and two events those are used for calculating time. + orEventRecord(start_event, stream1); + orLaunchKernel(stream1, add_kernel, dev_out, dev_a, dev_b, num); + orEventRecord(stop_event, stream1); + + // Synchronization between streams. + orStreamWaitEvent(stream2, stop_event, 0); + orMemcpyAsync(host_out.data(), dev_out, size, orMemcpyDeviceToHost, stream2); + orStreamSynchronize(stream2); + + std::cout << "All tasks have been submitted." << std::endl; + + float elapsed_ms = 0.0f; + orEventElapsedTime(&elapsed_ms, start_event, stop_event); + std::cout << "Kernel execution time: " << elapsed_ms << " ms" << std::endl; + + bool success = true; + for (int i = 0; i < num; ++i) { + if (std::abs(host_out[i] - (host_a[i] + host_b[i])) > 1e-5) { + std::cout << "Verification FAILED at index " << i << "! Expected " + << (host_a[i] + host_b[i]) << ", got " << host_out[i] + << std::endl; + success = false; + break; + } + } + if (success) { + std::cout << "Verification PASSED!" << std::endl; + } + + orFree(dev_a); + orFree(dev_b); + orFree(dev_out); + + orStreamDestroy(stream1); + orStreamDestroy(stream2); + + orEventDestroy(start_event); + orEventDestroy(stop_event); + + return 0; +} diff --git a/PyTorchSimDevice/third_party/openreg/include/openreg.h b/PyTorchSimDevice/third_party/openreg/include/openreg.h new file mode 100644 index 00000000..a5e4af55 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/include/openreg.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +#ifdef _WIN32 +#define OPENREG_EXPORT __declspec(dllexport) +#else +#define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum orError_t { + orSuccess = 0, + orErrorUnknown = 1, + orErrorNotReady = 2 +} orError_t; + +typedef enum orMemcpyKind { + orMemcpyHostToHost = 0, + orMemcpyHostToDevice = 1, + orMemcpyDeviceToHost = 2, + orMemcpyDeviceToDevice = 3 +} orMemcpyKind; + +typedef enum orMemoryType { + orMemoryTypeUnmanaged = 0, + orMemoryTypeHost = 1, + orMemoryTypeDevice = 2 +} orMemoryType; + +struct orPointerAttributes { + orMemoryType type = orMemoryType::orMemoryTypeUnmanaged; + int device; + void* pointer; +}; + +typedef enum orEventFlags { + orEventDisableTiming = 0x0, + orEventEnableTiming = 0x1, +} orEventFlags; + +struct orStream; +struct orEvent; +typedef struct orStream* orStream_t; +typedef struct orEvent* orEvent_t; + +// Memory +OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size); +OPENREG_EXPORT orError_t orFree(void* devPtr); +OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size); +OPENREG_EXPORT orError_t orFreeHost(void* hostPtr); +OPENREG_EXPORT orError_t +orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind); +OPENREG_EXPORT orError_t orMemcpyAsync( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind, + orStream_t stream); +OPENREG_EXPORT orError_t +orPointerGetAttributes(orPointerAttributes* attributes, const void* ptr); +OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr); +OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr); + +// Device +OPENREG_EXPORT orError_t orGetDeviceCount(int* count); +OPENREG_EXPORT orError_t orSetDevice(int device); +OPENREG_EXPORT orError_t orGetDevice(int* device); +OPENREG_EXPORT orError_t +orDeviceGetStreamPriorityRange(int* leastPriority, int* greatestPriority); +OPENREG_EXPORT orError_t orDeviceSynchronize(void); + +// Stream +OPENREG_EXPORT orError_t orStreamCreateWithPriority( + orStream_t* stream, + unsigned int flags, + int priority); +OPENREG_EXPORT orError_t orStreamCreate(orStream_t* stream); +OPENREG_EXPORT orError_t orStreamGetPriority(orStream_t stream, int* priority); +OPENREG_EXPORT orError_t orStreamDestroy(orStream_t stream); +OPENREG_EXPORT orError_t orStreamQuery(orStream_t stream); +OPENREG_EXPORT orError_t orStreamSynchronize(orStream_t stream); +OPENREG_EXPORT orError_t +orStreamWaitEvent(orStream_t stream, orEvent_t event, unsigned int flags); + +// Event +OPENREG_EXPORT orError_t +orEventCreateWithFlags(orEvent_t* event, unsigned int flags); +OPENREG_EXPORT orError_t orEventCreate(orEvent_t* event); +OPENREG_EXPORT orError_t orEventDestroy(orEvent_t event); +OPENREG_EXPORT orError_t orEventRecord(orEvent_t event, orStream_t stream); +OPENREG_EXPORT orError_t orEventSynchronize(orEvent_t event); +OPENREG_EXPORT orError_t orEventQuery(orEvent_t event); +OPENREG_EXPORT orError_t +orEventElapsedTime(float* ms, orEvent_t start, orEvent_t end); + +#ifdef __cplusplus +} // extern "C" +#endif + +#ifdef __cplusplus + +#define OPENREG_H +#include "openreg.inl" + +#endif diff --git a/PyTorchSimDevice/third_party/openreg/include/openreg.inl b/PyTorchSimDevice/third_party/openreg/include/openreg.inl new file mode 100644 index 00000000..851be132 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/include/openreg.inl @@ -0,0 +1,42 @@ +#ifndef OPENREG_H +#error "Don`t include openreg.inl directly, include openreg.h instead." +#endif + +#include +#include +#include + +namespace openreg { +OPENREG_EXPORT orError_t +addTaskToStream(orStream* stream, std::function task); +} + +template +OPENREG_EXPORT inline orError_t orLaunchKernel( + orStream* stream, + Func&& kernel_func, + Args&&... args) { + if (!stream) { + return orErrorUnknown; + } + +/* + * Some tests in PyTorch still use C++11, so we use conditional macro to + * select different approaches for different C++ version. + * + * Std::apply is only supported in C++17, so for C++11/14, std::bind is + * a more appropriate approach, but the former has better performance. + */ +#if __cplusplus >= 201703L + auto task = [func = std::forward(kernel_func), + args_tuple = + std::make_tuple(std::forward(args)...)]() mutable { + std::apply(func, std::move(args_tuple)); + }; +#else + auto task = + std::bind(std::forward(kernel_func), std::forward(args)...); +#endif + + return openreg::addTaskToStream(stream, std::move(task)); +} diff --git a/PyTorchSimDevice/torch_openreg/__init__.py b/PyTorchSimDevice/torch_openreg/__init__.py new file mode 100644 index 00000000..e8158391 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/__init__.py @@ -0,0 +1,34 @@ +import sys +import os +import torch + + +if sys.platform == "win32": + from ._utils import _load_dll_libraries + + _load_dll_libraries() + del _load_dll_libraries + +import torch_openreg._C # type: ignore[misc] +import torch_openreg.openreg + +torch.utils.rename_privateuse1_backend("npu") +torch._register_device_module("npu", torch_openreg.openreg) +torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) + +sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +import PyTorchSimFrontend.extension_config # noqa: F401 +from PyTorchSimFrontend.mlir.mlir_codegen_backend import ExtensionWrapperCodegen +from PyTorchSimFrontend.mlir.mlir_scheduling import MLIRScheduling +torch._inductor.codegen.common.register_backend_for_device( + "npu", + lambda scheduling: MLIRScheduling(scheduling), + ExtensionWrapperCodegen +) + +torch_openreg.openreg.init() +sys.modules['torch.npu'] = torch_openreg.openreg + +def _autoload(): + # It is a placeholder function here to be registered as an entry point. + pass \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/_utils.py b/PyTorchSimDevice/torch_openreg/_utils.py new file mode 100644 index 00000000..1c26f475 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/_utils.py @@ -0,0 +1,42 @@ +import ctypes +import glob +import os + + +def _load_dll_libraries() -> None: + openreg_dll_path = os.path.join(os.path.dirname(__file__), "lib") + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + kernel32.LoadLibraryW.restype = ctypes.c_void_p + if with_load_library_flags: + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + os.add_dll_directory(openreg_dll_path) + + dlls = glob.glob(os.path.join(openreg_dll_path, "*.dll")) + path_patched = False + for dll in dlls: + is_loaded = False + if with_load_library_flags: + res = kernel32.LoadLibraryExW(dll, None, 0x00001100) + last_error = ctypes.get_last_error() + if res is None and last_error != 126: + err = ctypes.WinError(last_error) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + elif res is not None: + is_loaded = True + if not is_loaded: + if not path_patched: + os.environ["PATH"] = ";".join([openreg_dll_path] + [os.environ["PATH"]]) + path_patched = True + res = kernel32.LoadLibraryW(dll) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + + kernel32.SetErrorMode(prev_error_mode) diff --git a/PyTorchSimDevice/torch_openreg/csrc/CMakeLists.txt b/PyTorchSimDevice/torch_openreg/csrc/CMakeLists.txt new file mode 100644 index 00000000..2a29a89c --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/csrc/CMakeLists.txt @@ -0,0 +1,28 @@ +set(LIBRARY_NAME torch_bindings) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_include_directories(${LIBRARY_NAME} PRIVATE + ${PROJECT_SOURCE_DIR}/third_party/openreg +) + +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg) + +if(WIN32) + find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + target_link_libraries(${LIBRARY_NAME} PRIVATE ${Python3_LIBRARIES}) +elseif(APPLE) + set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") +endif() + +target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib) + +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/PyTorchSimDevice/torch_openreg/csrc/Module.cpp b/PyTorchSimDevice/torch_openreg/csrc/Module.cpp new file mode 100644 index 00000000..e4f3e8d1 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/csrc/Module.cpp @@ -0,0 +1,280 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +static PyObject* _initExtension(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + + at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "_get_default_generator expects an int, but got ", + THPUtils_typename(arg)); + auto idx = static_cast(THPUtils_unpackLong(arg)); + + return THPGenerator_initDefaultGenerator( + at::globalContext().defaultGenerator( + c10::Device(c10::DeviceType::PrivateUse1, idx))); + + END_HANDLE_TH_ERRORS +} + +PyObject* _setDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice"); + auto device = THPUtils_unpackLong(arg); + + torch::utils::device_lazy_init(at::kPrivateUse1); + c10::openreg::set_device(static_cast(device)); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _exchangeDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); + auto device_index = THPUtils_unpackDeviceIndex(arg); + if (device_index < 0) { + return THPUtils_packInt32(-1); + } + + torch::utils::device_lazy_init(at::kPrivateUse1); + auto current_device = c10::openreg::ExchangeDevice(device_index); + + return THPUtils_packDeviceIndex(current_device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDevice(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + torch::utils::device_lazy_init(at::kPrivateUse1); + auto device = static_cast(c10::openreg::current_device()); + return THPUtils_packInt32(device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64(c10::openreg::device_count()); + END_HANDLE_TH_ERRORS +} + +PyObject* _isAutocastEnabled(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + if (c10::openreg::is_amp_enabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* _setAutocastEnabled(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + PyBool_Check(arg), + "set_autocast_enabled expects a bool, but got ", + THPUtils_typename(arg)); + c10::openreg::set_amp_enabled(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _getAutocastDtype(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + THPDtype* dtype_obj = torch::getTHPDtype(c10::openreg::get_amp_dtype()); + Py_INCREF(dtype_obj); + return reinterpret_cast(dtype_obj); + END_HANDLE_TH_ERRORS +} + +PyObject* _setAutocastDtype(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPDtype_Check(arg), + "set_autocast_dtype expects a dtype, but got ", + THPUtils_typename(arg)); + THPDtype* dtype_obj = reinterpret_cast(arg); + at::ScalarType dtype = dtype_obj->scalar_type; + c10::openreg::set_amp_dtype(dtype); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _getAmpSupportedDtype(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + PyObject* torch_mod = PyImport_ImportModule("torch"); + TORCH_CHECK(torch_mod != nullptr, "Failed to import torch module"); + + PyObject* float16 = PyObject_GetAttrString(torch_mod, "float16"); + PyObject* float32 = PyObject_GetAttrString(torch_mod, "float32"); + + PyObject* lst = PyList_New(1); + PyList_SetItem(lst, 0, float32); + //PyList_SetItem(lst, 1, float32); + + Py_DECREF(torch_mod); + return lst; + END_HANDLE_TH_ERRORS +} + +// Stream functions +PyObject* _streamCreate(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + torch::utils::device_lazy_init(at::kPrivateUse1); + orStream_t stream = nullptr; + orError_t err = orStreamCreate(&stream); + if (err != orSuccess) { + TORCH_CHECK(false, "Failed to create stream"); + } + return THPUtils_packInt64(reinterpret_cast(stream)); + END_HANDLE_TH_ERRORS +} + +PyObject* _streamCreateWithPriority(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + TORCH_CHECK(PyTuple_Size(args) == 2, "stream_create_with_priority expects 2 arguments"); + PyObject* flags_obj = PyTuple_GetItem(args, 0); + PyObject* priority_obj = PyTuple_GetItem(args, 1); + TORCH_CHECK(THPUtils_checkLong(flags_obj), "flags must be an int"); + TORCH_CHECK(THPUtils_checkLong(priority_obj), "priority must be an int"); + unsigned int flags = static_cast(THPUtils_unpackLong(flags_obj)); + int priority = static_cast(THPUtils_unpackLong(priority_obj)); + + torch::utils::device_lazy_init(at::kPrivateUse1); + orStream_t stream = nullptr; + orError_t err = orStreamCreateWithPriority(&stream, flags, priority); + if (err != orSuccess) { + TORCH_CHECK(false, "Failed to create stream with priority"); + } + return THPUtils_packInt64(reinterpret_cast(stream)); + END_HANDLE_TH_ERRORS +} + +PyObject* _streamDestroy(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "stream_destroy expects an int"); + orStream_t stream = reinterpret_cast(THPUtils_unpackLong(arg)); + orError_t err = orStreamDestroy(stream); + if (err != orSuccess) { + TORCH_CHECK(false, "Failed to destroy stream"); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _deviceSynchronize(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + torch::utils::device_lazy_init(at::kPrivateUse1); + + orError_t err; + Py_BEGIN_ALLOW_THREADS + err = orDeviceSynchronize(); + Py_END_ALLOW_THREADS + + if (err != orSuccess) { + TORCH_CHECK(false, "Failed to synchronize device"); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _addTaskToStream(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + TORCH_CHECK(PyTuple_Size(args) == 2, "add_task_to_stream expects 2 arguments"); + PyObject* stream_obj = PyTuple_GetItem(args, 0); + PyObject* callable_obj = PyTuple_GetItem(args, 1); + + TORCH_CHECK(THPUtils_checkLong(stream_obj), "stream must be an int"); + TORCH_CHECK(PyCallable_Check(callable_obj), "task must be callable"); + + orStream_t stream = reinterpret_cast(THPUtils_unpackLong(stream_obj)); + + Py_INCREF(callable_obj); + auto py_callable = std::shared_ptr(callable_obj, [](PyObject* obj) { + PyGILState_STATE gstate = PyGILState_Ensure(); + Py_DECREF(obj); + PyGILState_Release(gstate); + }); + + auto task = [py_callable]() { + PyGILState_STATE gstate = PyGILState_Ensure(); + try { + PyObject* result = PyObject_CallObject(py_callable.get(), nullptr); + if (result == nullptr) { + PyErr_Print(); + PyErr_Clear(); + } else { + Py_DECREF(result); + } + } catch (...) { + } + + PyGILState_Release(gstate); + }; + orError_t err = openreg::addTaskToStream(stream, task); + if (err != orSuccess) { + TORCH_CHECK(false, "Failed to add task to stream"); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyMethodDef methods[] = { + {"_init", _initExtension, METH_NOARGS, nullptr}, + {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr}, + {"_get_device", _getDevice, METH_NOARGS, nullptr}, + {"_set_device", _setDevice, METH_O, nullptr}, + {"_exchangeDevice", _exchangeDevice, METH_O, nullptr}, + {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr}, + {"is_autocast_enabled", _isAutocastEnabled, METH_NOARGS, nullptr}, + {"set_autocast_enabled", _setAutocastEnabled, METH_O, nullptr}, + {"get_autocast_dtype", _getAutocastDtype, METH_NOARGS, nullptr}, + {"set_autocast_dtype", _setAutocastDtype, METH_O, nullptr}, + {"get_amp_supported_dtype", _getAmpSupportedDtype, METH_NOARGS, nullptr}, + // Stream functions + {"_stream_create", _streamCreate, METH_NOARGS, nullptr}, + {"_stream_destroy", _streamDestroy, METH_O, nullptr}, + + // Device functions + {"_device_synchronize", _deviceSynchronize, METH_NOARGS, nullptr}, + // Stream task functions + {"_add_task_to_stream", _addTaskToStream, METH_VARARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}}; + +/* + * When ASAN is enabled, PyTorch modifies the dlopen flag during import, + * causing all global and weak symbols in _C.so and its dependent libraries + * to be exposed to the global symbol scope, which in turn causes + * subsequent symbols with the same name in other libraries to be intercepted. + * Therefore, it cannot be named initModule here, otherwise initModule + * in torch/csrc/Module.cpp will be called, resulting in failure. + */ +extern "C" OPENREG_EXPORT PyObject* initOpenRegModule(void) { + static struct PyModuleDef openreg_C_module = { + PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods}; + PyObject* mod = PyModule_Create(&openreg_C_module); + + return mod; +} diff --git a/PyTorchSimDevice/torch_openreg/csrc/stub.c b/PyTorchSimDevice/torch_openreg/csrc/stub.c new file mode 100644 index 00000000..4e02f9fd --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/csrc/stub.c @@ -0,0 +1,20 @@ +#include + +#ifdef _WIN32 +#define OPENREG_EXPORT __declspec(dllexport) +#else +#define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif + +extern OPENREG_EXPORT PyObject* initOpenRegModule(void); + +#ifdef __cplusplus +extern "C" +#endif + + OPENREG_EXPORT PyObject* + PyInit__C(void); + +PyMODINIT_FUNC PyInit__C(void) { + return initOpenRegModule(); +} diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py new file mode 100644 index 00000000..592011aa --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -0,0 +1,334 @@ +import os +import threading + +import torch +from torch._dynamo.device_interface import register_interface_for_device +import torch_openreg._C # type: ignore[misc] + +from . import meta # noqa: F401 +from . import extension_device_op_overrides +from .extension_device_interface import ExtensionDeviceInterface + +_initialized = False +_default_streams = {} # Dictionary to store default streams per device +_tog_simulator = None # Singleton TOGSimulator instance +_launch_context = threading.local() # storage for launch_kernel context + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device): + self.idx = torch.accelerator._get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch_openreg._C._exchangeDevice(self.idx) + + def __exit__(self, type, value, traceback): + self.idx = torch_openreg._C._set_device(self.prev_idx) + return False + + +def is_available(): + return True + + +def device_count() -> int: + return torch_openreg._C._get_device_count() + + +def current_device(): + return torch_openreg._C._get_device() + + +def set_device(device) -> None: + return torch_openreg._C._set_device(device) + +def custom_device(): + return torch.device("npu:0") + +def init(): + _lazy_init() + + +def is_initialized(): + return _initialized + + +def _lazy_init(): + global _initialized, _tog_simulator + if is_initialized(): + return + + # Replace the global C++ binding with our custom dispatcher patch + # from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention + # torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention + + torch_openreg._C._init() + register_interface_for_device(custom_device(), ExtensionDeviceInterface) + _initialized = True + + # Set default SDPA backend to math-only for this device. + torch._C._set_sdp_use_flash(False) + torch._C._set_sdp_use_overrideable(False) + torch._C._set_sdp_use_math(True) + + # Create default streams for all devices + num_devices = device_count() + for device_idx in range(num_devices): + _default_streams[device_idx] = Stream() + +class Stream: + """Wrapper for OpenReg stream.""" + + def __init__(self, flags=0): + self._stream = torch_openreg._C._stream_create() + + def __del__(self): + # Interpreter shutdown can clear module globals before __del__ runs. + # Only destroy when both runtime handle and stream are still valid. + stream = getattr(self, "_stream", None) + backend = globals().get("torch_openreg", None) + c_api = getattr(backend, "_C", None) if backend is not None else None + if stream is None or c_api is None: + return + destroy = getattr(c_api, "_stream_destroy", None) + if destroy is None: + return + try: + destroy(stream) + except (AttributeError, TypeError): + # Ignore cleanup-time teardown ordering issues. + pass + + def launch_kernel(self, task): + """Add a Python callable kernel to this stream. + + Args: + task: A Python callable (function) to be executed in the stream + """ + torch_openreg._C._add_task_to_stream(self._stream, task) + + @property + def cdata(self): + """Get the underlying stream pointer (for internal use).""" + return self._stream + +def stream(flags=0): + return Stream(flags=flags) + +def default_stream(device=None): + _lazy_init() + if device is None: + device_idx = current_device() + else: + device_idx = torch.accelerator._get_device_index(device, optional=True) + if device_idx < 0: + device_idx = current_device() + + if device_idx not in _default_streams: + # Create default stream if it doesn't exist + _default_streams[device_idx] = Stream() + + return _default_streams[device_idx] + + +def launch_kernel(tog_path, attribute_path): + """Launch a kernel on TOGSimulator. + + Args: + tog_path: Path to TOG file + attribute_path: Path to attribute file + + Returns: + int: The kernel ID assigned to this launch + + """ + # Get TOGSimulator instance + sim = get_tog_simulator() + if sim is None: + raise RuntimeError("[torch.npu] TOGSimulator is not initialized. Call torch.npu.init() first.") + + device_idx = current_device() + stream_index, timestamp = get_launch_context() + # Create a task function that calls TOGSimulator.launch_kernel + def launch_task(): + return sim.launch_kernel(device_idx, stream_index, tog_path, attribute_path, timestamp) + + stream = default_stream() + stream.launch_kernel(launch_task) + +def synchronize(): + """Synchronize all streams on the current device. + + This function: + 1. Registers TOGSimulator.device_synchronize as a task on the default stream + 2. Calls the underlying device_synchronize to wait for all tasks to complete + """ + # Get TOGSimulator instance + sim = get_tog_simulator() + if sim is not None: + # Get current device index + device_idx = current_device() + + # Create a task function that calls TOGSimulator.device_synchronize + def sync_task(): + return sim.device_synchronize(device_idx) + + # Register as task on default stream + stream = default_stream() + stream.launch_kernel(sync_task) + + # Call underlying device_synchronize to wait for all tasks to complete + torch_openreg._C._device_synchronize() + +def get_tog_simulator(): + return _tog_simulator + +def set_tog_simulator(simulator): + """Set the global TOGSimulator instance. + + Args: + simulator: TOGSimulator instance or None + """ + global _tog_simulator + _tog_simulator = simulator + +def set_launch_context(stream_index=0, timestamp=0): + _launch_context.stream_index = stream_index + _launch_context.timestamp = timestamp + +def get_launch_context(): + stream_index = getattr(_launch_context, 'stream_index', 0) + timestamp = getattr(_launch_context, 'timestamp', 0) + return stream_index, timestamp + +class launch_context: + """Context manager for setting launch_kernel parameters. + + Args: + stream_index: Stream index (partition ID) to use for launch_kernel + timestamp: Timestamp in nanoseconds to use for launch_kernel + + Example: + with torch.npu.launch_context(stream_index=1, timestamp=1000): + model(input) + """ + + def __init__(self, stream_index=0, timestamp=0): + self.stream_index = stream_index + self.timestamp = timestamp + self.prev_stream_index = None + self.prev_timestamp = None + + def __enter__(self): + # Save previous context values + self.prev_stream_index = getattr(_launch_context, 'stream_index', 0) + self.prev_timestamp = getattr(_launch_context, 'timestamp', 0) + # Set new context values + set_launch_context(self.stream_index, self.timestamp) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Restore previous context values + _launch_context.stream_index = self.prev_stream_index + _launch_context.timestamp = self.prev_timestamp + return False + +def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs): + """Launch a compiled model on TOGSimulator. + + Args: + model: Compiled model (torch.compile()) + *args: Model input arguments + stream_index: Stream index (partition ID). If None, uses context value. + timestamp: Timestamp in nanoseconds. If None, uses context value. + **kwargs: Additional keyword arguments for model execution + + Returns: + Model output (same as calling model(*args, **kwargs)) + + Note: + This function executes the compiled model and automatically launches + the generated kernels with the specified stream_index and timestamp. + If stream_index or timestamp are not provided, values from the current + context (set via launch_context() or set_launch_context()) are used. + """ + # Get stream_index and timestamp from parameters or context + with launch_context(stream_index=stream_index, timestamp=timestamp): + return model(*args, **kwargs) + +from .random import * # noqa: F403 +from .amp import * + +def eager_to_compile(op_name): + """ + Register an eager mode operation as a graph-based implementation using torch.compile(). + + Args: + op_name: Operator name (e.g., "aten::mul.Tensor") + + Example: + torch.npu.eager_to_compile("aten::mul.Tensor") + """ + def wrapper(*args, **kwargs): + @torch.compile(dynamic=False) + def dummy_graph(*args, **kwargs): + # Convert "aten::mul.Tensor" -> torch.ops.aten.mul.Tensor + namespace, op_path = op_name.split("::", 1) + op_path_parts = op_path.split(".") + op = torch.ops + for part in [namespace] + op_path_parts: + op = getattr(op, part) + return op(*args, **kwargs) + return dummy_graph(*args, **kwargs) + + torch.library.impl(op_name, "npu", wrapper) + +def register_eager_to_compile(ops): + """ + Register multiple operators at once using eager_to_compile. + + Args: + ops: List of operator names (e.g., ["aten::mul.Tensor", "aten::add.Tensor"]) + + Example: + torch.npu.register_eager_to_compile(["aten::mul.Tensor", "aten::add.Tensor"]) + """ + for op_name in ops: + eager_to_compile(op_name) + +__all__ = [ + "device", + "device_count", + "current_device", + "set_device", + "custom_device", + "initial_seed", + "is_available", + "init", + "is_initialized", + "random", + "manual_seed", + "manual_seed_all", + "get_rng_state", + "set_rng_state", + "is_autocast_enabled", + "set_autocast_enabled", + "get_autocast_dtype", + "set_autocast_dtype", + "get_amp_supported_dtype", + "stream", + "launch_kernel", + "launch_model", + "synchronize", + "get_tog_simulator", + "set_tog_simulator", + "eager_to_compile", + "register_eager_to_compile", +] diff --git a/PyTorchSimDevice/torch_openreg/openreg/amp.py b/PyTorchSimDevice/torch_openreg/openreg/amp.py new file mode 100644 index 00000000..0a9dfdf0 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/amp.py @@ -0,0 +1,33 @@ +import torch + +import torch_openreg._C # type: ignore[misc] + +from . import _lazy_init + + +__all__ = [ + "is_autocast_enabled", + "set_autocast_enabled", + "get_autocast_dtype", + "set_autocast_dtype", + "get_amp_supported_dtype", +] + +def is_autocast_enabled(): + return torch_openreg._C.is_autocast_enabled() + + +def set_autocast_enabled(enabled: bool) -> None: + torch_openreg._C.set_autocast_enabled(enabled) + + +def get_autocast_dtype(): + return torch_openreg._C.get_autocast_dtype() + + +def set_autocast_dtype(dtype) -> None: + torch_openreg._C.set_autocast_dtype(dtype) + + +def get_amp_supported_dtype(): + return torch_openreg._C.get_amp_supported_dtype() \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/openreg/extension_device_interface.py b/PyTorchSimDevice/torch_openreg/openreg/extension_device_interface.py new file mode 100644 index 00000000..e5875ab7 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/extension_device_interface.py @@ -0,0 +1,63 @@ +import torch +from torch._dynamo.device_interface import DeviceInterface, caching_worker_current_devices, caching_worker_device_properties + +class _ExtensionDeviceProperties: # FIXME: Dummy property values + name: str = "Extension_device" + platform_name: str + vendor: str + driver_version: str + version: str + max_compute_units: int + gpu_eu_count: int + max_work_group_size: int + max_num_sub_groups: int + sub_group_sizes: list[int] + has_fp16: bool + has_fp64: bool + has_atomic64: bool + has_bfloat16_conversions: bool + has_subgroup_matrix_multiply_accumulate: bool + has_subgroup_matrix_multiply_accumulate_tensor_float32: bool + has_subgroup_2d_block_io: bool + total_memory: int + multi_processor_count: int = 128 # gpu_subslice_count, num_sm + architecture: int + type: str + +_ExtensionDeviceProperties = _ExtensionDeviceProperties + +class ExtensionDeviceInterface(DeviceInterface): + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["extension_device"] = device + + @staticmethod + def current_device() -> int: + if "extension_device" in caching_worker_current_devices: + return caching_worker_current_devices["extension_device"] + return torch.xpu.current_device() + + @staticmethod + def get_device_properties(device: torch.types.Device = None) -> _ExtensionDeviceProperties: + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "extension_device" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = ExtensionDeviceInterface.Worker.current_device() + + if "extension_device" not in caching_worker_device_properties: + device_prop = [ + torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + ] + caching_worker_device_properties["extension_device"] = device_prop + + return _ExtensionDeviceProperties + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + return 36 \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/openreg/extension_device_op_overrides.py b/PyTorchSimDevice/torch_openreg/openreg/extension_device_op_overrides.py new file mode 100644 index 00000000..27a47357 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/extension_device_op_overrides.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from textwrap import dedent + +from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides +from torch._inductor.codegen.cpu_device_op_overrides import CpuDeviceOpOverrides + +class ExtensionDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def set_device(self, device_idx: int) -> str: + return "pass" + + def synchronize(self) -> str: + return "pass" + + def device_guard(self, device_idx: int) -> str: + return "pass" + +register_device_op_overrides("npu", ExtensionDeviceOpOverrides()) +register_device_op_overrides("cpu", CpuDeviceOpOverrides()) \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/openreg/meta.py b/PyTorchSimDevice/torch_openreg/openreg/meta.py new file mode 100644 index 00000000..c475e8e0 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/meta.py @@ -0,0 +1,13 @@ +import torch + + +# LITERALINCLUDE START: CUSTOM OPERATOR META +lib = torch.library.Library("openreg", "IMPL", "Meta") # noqa: TOR901 + + +@torch.library.impl(lib, "custom_abs") +def custom_abs(self): + return torch.empty_like(self) + + +# LITERALINCLUDE END: CUSTOM OPERATOR META diff --git a/PyTorchSimDevice/torch_openreg/openreg/random.py b/PyTorchSimDevice/torch_openreg/openreg/random.py new file mode 100644 index 00000000..3f2e99fe --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/random.py @@ -0,0 +1,67 @@ +import torch + +import torch_openreg._C # type: ignore[misc] + +from . import _lazy_init, current_device, device_count + + +__all__ = [ + "get_rng_state", + "set_rng_state", + "manual_seed", + "manual_seed_all", + "initial_seed", + "_is_in_bad_fork", +] + + +def get_rng_state(device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.get_state() + + +def set_rng_state(new_state, device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.set_state(new_state) + + +def initial_seed() -> int: + _lazy_init() + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.initial_seed() + + +def manual_seed(seed: int) -> None: + seed = int(seed) + + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + + +def manual_seed_all(seed: int) -> None: + seed = int(seed) + + for idx in range(device_count()): + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + +def _is_in_bad_fork(): + # For NPU simulator, we don't have the same fork issues as CUDA + # Return False to indicate we're not in a bad fork state + return False \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 4d57b987..65c96f11 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -2,20 +2,30 @@ import re import shlex import subprocess +import torch -from torch._inductor.codecache import AsyncCompile, get_lock_dir, get_hash, write +from PyTorchSimFrontend import extension_config +from torch._inductor.codecache import get_hash, write +from torch._inductor.async_compile import AsyncCompile from AsmParser.tog_generator import tog_generator from PyTorchSimFrontend.mlir.mlir_caller_codegen import MLIRKernelCallerCodeGen -from PyTorchSimFrontend import extension_config from Simulator.simulator import FunctionalSimulator, CycleSimulator, TOGSimulator +# Configure logger for extension_codecache module (WARNING level by default) +logger = extension_config.setup_logger() + LOCK_TIMEOUT = 600 def hash_prefix(hash_value): return hash_value[1:12] def get_write_path(src_code): - return os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "outputs", hash_prefix(get_hash(src_code.strip()))) + return os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, hash_prefix(get_hash(src_code.strip()))) + + +def get_lock_path(write_path): + """Return lock file path for the given write_path (per-source_code lock).""" + return os.path.join(write_path, ".compile.lock") def dump_metadata(args, arg_attributes, path): meta_path = os.path.join(path, "meta.txt") @@ -62,8 +72,17 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ - -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ + -filetype=obj \ {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ + -O2 {filename}.ll -o {filename}.o + """, + ).strip(), + re.sub(r"[ \n]+", " ", + f""" + {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ + -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ -O2 {filename}.ll -o {filename}.s """, ).strip()] @@ -104,9 +123,10 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ - -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ + -filetype=obj \ {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ - -O2 {sample_filename}.ll -o {sample_filename}.s + -O2 {sample_filename}.ll -o {sample_filename}.o """, ).strip()] @@ -140,12 +160,14 @@ def load(cls, source_code, key, input_path = write(source_code, "mlir", specified_dir=write_path) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" + tog_path = os.path.join(write_path, "tile_graph.onnx") sample_mlir_path = new_input_path + "_sample" + validation_binary_path = os.path.join(write_path, validation_binary_name) gem5_cmds = mlir_gem5_compile_command(new_input_path, sample_mlir_path, raw_tog_path, vectorlane_size) from filelock import FileLock - lock_dir = get_lock_dir() - lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + os.makedirs(write_path, exist_ok=True) + lock = FileLock(get_lock_path(write_path), timeout=LOCK_TIMEOUT) if spad_info is not None: link_option = f"-Wl,--section-start=.spad=0x{spad_info['spad_vaddr']:x}" @@ -159,35 +181,43 @@ def load(cls, source_code, opt_cmd = shlex.split(cmds[0]) translate_cmd = shlex.split(cmds[1]) llc_cmd = shlex.split(cmds[2]) + llc_asm_cmd = shlex.split(cmds[3]) with lock: try: subprocess.check_call(opt_cmd) subprocess.check_call(translate_cmd) subprocess.check_call(llc_cmd) + subprocess.check_call(llc_asm_cmd) except subprocess.CalledProcessError as e: - print("Command failed with exit code", e.returncode) - print("Error output:", e.output) + logger.error(f"Command failed with exit code {e.returncode}") + logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert(0) val_llvm_caller = MLIRKernelCallerCodeGen(extension_config.pytorchsim_functional_mode, arg_attributes) val_llvm_caller.generate_wrapper_file(write_path, validation_wrapper_name) val_llvm_caller.compile_wih_kernel(write_path, key, validation_wrapper_name, validation_binary_name, new_link_option) - target = os.path.join(write_path, validation_binary_name) + stack_size = val_llvm_caller.parse_stack_sizes(f"{write_path}/{key}.s", vlenb=vlenb) - spad_size = val_llvm_caller.get_spad_size(target) + spad_size = val_llvm_caller.get_spad_size(validation_binary_path) spad_usage = stack_size + spad_size # Spad usage per lane if extension_config.CONFIG_SPAD_INFO["spad_size"] < spad_usage: - print(f"[Warning] Scratchpad size exceeded: required {spad_usage} bytes, " - f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available.") + logger.debug( + f"Scratchpad size exceeded: required {spad_usage} bytes, " + f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available." + ) raise SpadOverflowError() + # Skip if TOG file already exists + if os.path.isfile(tog_path): + return key + # Launch tile graph generator gem5_sample_cmd = shlex.split(gem5_cmds[0]) gem5_translate_cmd = shlex.split(gem5_cmds[1]) gem5_llc_cmd = shlex.split(gem5_cmds[2]) - lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + lock = FileLock(get_lock_path(write_path), timeout=LOCK_TIMEOUT) with lock: try: result = subprocess.check_output(gem5_sample_cmd) @@ -196,8 +226,8 @@ def load(cls, source_code, subprocess.check_call(gem5_translate_cmd) subprocess.check_call(gem5_llc_cmd) except subprocess.CalledProcessError as e: - print("Command failed with exit code", e.returncode) - print("Error output:", e.output) + logger.error(f"Command failed with exit code {e.returncode}") + logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert(0) if not extension_config.pytorchsim_timing_mode: @@ -207,13 +237,10 @@ def load(cls, source_code, cycle_llvm_caller = MLIRKernelCallerCodeGen(False, arg_attributes, cycle_sim=True) cycle_llvm_caller.generate_wrapper_file(write_path, cycle_wrapper_name) cycle_llvm_caller.compile_wih_kernel(write_path, key + "_sample", cycle_wrapper_name, cycle_binary_name, link_option) - array_size = [] - for (arg_name, arg_attribute) in arg_attributes: - array_size.append(str(arg_attribute[2])) # Run cyclesim cyclesim = CycleSimulator() - cycle_list = cyclesim.compile_and_simulate(os.path.join(write_path, cycle_binary_name), " ".join(array_size), vectorlane_size, silent_mode=silent_mode) + cycle_list = cyclesim.compile_and_simulate(os.path.join(write_path, cycle_binary_name), vectorlane_size, silent_mode=silent_mode) # Create TOG w_offset, x_offset = vectorlane_size, vectorlane_size @@ -225,7 +252,7 @@ def load(cls, source_code, tile_graph_generator = tog_generator(origins) tile_graph_generator.load_file(raw_tog_path) tile_graph_generator.generate_tile_graph( - os.path.join(write_path, "tile_graph.onnx"), + tog_path, cycle_list=cycle_list, x_offset=x_offset, # FIXME. w_offset=w_offset, # FIXME. @@ -241,77 +268,50 @@ def __init__(self): self.cycle_binary_name = "cycle_binary" def mlir(self, source_code, arg_attributes=[], vectorlane_size=16, tile_size=[], spad_info=None, origins=None, silent_mode=False, **kwargs): + autotune = kwargs.get('autotune', False) def task(): key = MLIRCodeCache.load(source_code, valdiation_wrapper_name=self.validation_binary_name, validation_binary_name=self.validation_binary_name, arg_attributes=arg_attributes, vectorlane_size=vectorlane_size, tile_size=tile_size, spad_info=spad_info, origins=origins, - silent_mode=silent_mode, **kwargs) + silent_mode=autotune, **kwargs) return key future = self.submit(task) - if "loop_size" in kwargs: - loop_size = kwargs["loop_size"] - else: - loop_size = [] - # In the autotune mode, skip validation to speed up - autotune = kwargs.get('autotune', False) - validate = kwargs.get('validate', False) if not autotune else False - - def dummy_simulator(*args, **kwargs): + def run_kernel_simulation(*args, **kwargs): # Wait for compilation key = future.result() from filelock import FileLock - lock_dir = get_lock_dir() - lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, hash_prefix(key)) + lock = FileLock(get_lock_path(result_path), timeout=LOCK_TIMEOUT) with lock: # Run simulator pass - result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "outputs", hash_prefix(key)) # Dump arguments and meta data dump_metadata(args, arg_attributes, result_path) runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path) - if not autotune and (extension_config.pytorchsim_functional_mode or validate): + if extension_config.pytorchsim_functional_mode and not autotune: funcsim = FunctionalSimulator(result_path, key) funcsim.run_spike(args, arg_attributes, runtime_path, self.validation_binary_name, vectorlane_size=vectorlane_size, spad_info=spad_info, - silent_mode=silent_mode) + silent_mode=autotune) + if not extension_config.pytorchsim_timing_mode: - return + return [float("inf")] + # Prepare arguments for launch kernel onnx_path = os.path.join(result_path, "tile_graph.onnx") - attribute_path = os.path.join(runtime_path, "attribute") - togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") - TOGSim = TOGSimulator(togsim_path, extension_config.CONFIG_TOGSIM_CONFIG) - TOGSim.vectorlane_size = vectorlane_size - attribute_path = TOGSim.create_attribute_file(attribute_path, args, loop_size=loop_size) - result_path = TOGSim.simulation(onnx_path, attribute_path, silent_mode=silent_mode) - result = TOGSimulator.get_result_from_file(result_path) - return result - - def dryrun_simulator(*args, **kwargs): - key = future.result() - from filelock import FileLock - lock_dir = get_lock_dir() - lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) - with lock: - # Run simulator pass - result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "outputs", hash_prefix(key)) - # Dump arguments and meta data - dump_metadata(args, arg_attributes, result_path) - runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path) + attribute_dir = os.path.join(runtime_path, "attribute") + kernel_attribute_path = TOGSimulator.write_kernel_attribute_file(attribute_dir, args) - # Todo. Support valude dependent mode for graph mode - if False: # extension_config.pytorchsim_functional_mode: - funcsim = FunctionalSimulator(result_path, key) - funcsim.run_spike(args, arg_attributes, - runtime_path, self.validation_binary_name, - vectorlane_size=vectorlane_size, spad_info=spad_info) - return result_path, runtime_path, None - - is_dryrun = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) and not autotune - target_simulator = dryrun_simulator if is_dryrun else dummy_simulator - target_simulator.arg_attributes = arg_attributes - target_simulator.future = future - return target_simulator + TOGSim = torch.npu.get_tog_simulator() + if not autotune and TOGSim is not None: + torch.npu.launch_kernel(onnx_path, kernel_attribute_path) + result = None # No result for non-autotune mode + else: + result_path = TOGSimulator.run_standalone( + onnx_path, kernel_attribute_path, autotune_mode=autotune) + result = TOGSimulator.get_result_from_file(result_path) + return result + return run_kernel_simulation diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 239bbefe..5dec8a4b 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -1,63 +1,98 @@ import os import sys import importlib -import json +import yaml +import logging CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') CONFIG_GEM5_PATH = os.environ.get('GEM5_PATH', default="/workspace/gem5/build/RISCV/gem5.opt") CONFIG_TORCHSIM_LLVM_PATH = os.environ.get('TORCHSIM_LLVM_PATH', default="/usr/bin") +CONFIG_TORCHSIM_TOG_HOST_CC = os.environ.get("TORCHSIM_TOG_HOST_CC", "gcc") + +def _default_tog_host_cflags(): + """Host flags for ``dlopen``'d ``*_tog.so`` / ``tile_operation_graph.so``.""" + if os.environ.get("TORCHSIM_TOG_HOST_CFLAGS"): + return os.environ["TORCHSIM_TOG_HOST_CFLAGS"] + if True: #int(os.environ.get("TORCHSIM_TOG_SO_DEBUG", "0")): + return ( + "-g -Og -fno-omit-frame-pointer -fPIC -std=c11 " + "-Wall -Wextra -Wno-unused-variable -Wno-unused-parameter" + ) + return ( + "-O2 -fPIC -std=c11 -Wall -Wextra -Wno-unused-variable -Wno-unused-parameter" + ) + + +CONFIG_TORCHSIM_TOG_HOST_CFLAGS = _default_tog_host_cflags() + + +def _default_tog_host_ldflags(): + if os.environ.get("TORCHSIM_TOG_HOST_LDFLAGS"): + return os.environ["TORCHSIM_TOG_HOST_LDFLAGS"] + # Keep debug sections in .so; optional build-id helps GDB locate DWARF. + base = "-shared" + if int(os.environ.get("TORCHSIM_TOG_SO_DEBUG", "0")): + return base + " -Wl,--build-id" + return base + + +CONFIG_TORCHSIM_TOG_HOST_LDFLAGS = _default_tog_host_ldflags() + CONFIG_TORCHSIM_DUMP_MLIR_IR = int(os.environ.get("TORCHSIM_DUMP_MLIR_IR", default=False)) CONFIG_TORCHSIM_DUMP_LLVM_IR = int(os.environ.get("TORCHSIM_DUMP_LLVM_IR", default=False)) +CONFIG_TORCHSIM_DUMP_PATH = os.environ.get("TORCHSIM_DUMP_PATH", os.path.join(CONFIG_TORCHSIM_DIR, "outputs")) +CONFIG_TORCHSIM_LOG_PATH = os.environ.get("TORCHSIM_LOG_PATH", os.path.join(CONFIG_TORCHSIM_DIR, "togsim_results")) +os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CONFIG_TORCHSIM_DUMP_PATH, ".torchinductor") def __getattr__(name): # TOGSim config config_path = os.environ.get('TOGSIM_CONFIG', - default=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json") + default=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml") if name == "CONFIG_TOGSIM_CONFIG": return config_path - config_json = json.load(open(config_path, 'r')) + + with open(config_path, 'r') as f: + config_yaml = yaml.safe_load(f) # Hardware info config if name == "vpu_num_lanes": - return config_json["vpu_num_lanes"] + return config_yaml["vpu_num_lanes"] if name == "CONFIG_SPAD_INFO": return { "spad_vaddr" : 0xD0000000, "spad_paddr" : 0x2000000000, - "spad_size" : config_json["vpu_spad_size_kb_per_lane"] << 10 # Note: spad size per lane + "spad_size" : config_yaml["vpu_spad_size_kb_per_lane"] << 10 # Note: spad size per lane } - if name == "CONFIG_PRECISION": - return 4 # 32bit if name == "CONFIG_NUM_CORES": - return config_json["num_cores"] + return config_yaml["num_cores"] if name == "vpu_vector_length_bits": - return config_json["vpu_vector_length_bits"] + return config_yaml["vpu_vector_length_bits"] if name == "pytorchsim_functional_mode": - return config_json['pytorchsim_functional_mode'] + return config_yaml['pytorchsim_functional_mode'] if name == "pytorchsim_timing_mode": - return config_json['pytorchsim_timing_mode'] + return config_yaml['pytorchsim_timing_mode'] # Mapping strategy if name == "codegen_mapping_strategy": - codegen_mapping_strategy = config_json["codegen_mapping_strategy"] + codegen_mapping_strategy = config_yaml["codegen_mapping_strategy"] assert(codegen_mapping_strategy in ["heuristic", "autotune", "external-then-heuristic", "external-then-autotune"]), "Invalid mapping strategy!" return codegen_mapping_strategy if name == "codegen_external_mapping_file": - return config_json["codegen_external_mapping_file"] + return config_yaml["codegen_external_mapping_file"] # Autotune config if name == "codegen_autotune_max_retry": - return config_json["codegen_autotune_max_retry"] + return config_yaml["codegen_autotune_max_retry"] if name == "codegen_autotune_template_topk": - return config_json["codegen_autotune_template_topk"] + return config_yaml["codegen_autotune_template_topk"] # Compiler Optimization if name == "codegen_compiler_optimization": - opt_level = config_json["codegen_compiler_optimization"] + opt_level = config_yaml["codegen_compiler_optimization"] valid_opts = { "fusion", "reduction_epilogue", @@ -67,7 +102,7 @@ def __getattr__(name): "multi_tile_conv", "subtile" } - if opt_level == "all" or opt_level is "none": + if opt_level == "all" or opt_level == "none": pass elif isinstance(opt_level, list): # Check if provided list contains only valid options @@ -98,13 +133,6 @@ def __getattr__(name): if name == "CONFIG_TOGSIM_DEBUG_LEVEL": return os.environ.get("TOGSIM_DEBUG_LEVEL", "") - if name == "CONFIG_TORCHSIM_DUMP_PATH": - return os.environ.get('TORCHSIM_DUMP_PATH', default = CONFIG_TORCHSIM_DIR) - if name == "CONFIG_TORCHSIM_LOG_PATH": - return os.environ.get('TORCHSIM_DUMP_LOG_PATH', default = os.path.join(CONFIG_TORCHSIM_DIR, "togsim_results")) - - if name == "CONFIG_TOGSIM_EAGER_MODE": - return int(os.environ.get("TOGSIM_EAGER_MODE", default=False)) # SRAM Buffer allocation plan def load_plan_from_module(module_path): @@ -132,4 +160,43 @@ def load_plan_from_module(module_path): CONFIG_USE_TIMING_POOLING = int(os.environ.get('TORCHSIM_USE_TIMING_POOLING', default=0)) -CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) \ No newline at end of file +CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) + + +def setup_logger(name=None, level=None): + """ + Setup a logger with consistent formatting across all modules. + + Args: + name: Logger name (default: __name__ of calling module) + level: Logging level (default: DEBUG if CONFIG_DEBUG_MODE else INFO) + + Returns: + Logger instance + """ + if name is None: + import inspect + # Get the calling module's name + frame = inspect.currentframe().f_back + name = frame.f_globals.get('__name__', 'PyTorchSim') + + # Convert logger name to lowercase + name = name.lower() + logger = logging.getLogger(name) + + # Only configure if not already configured (avoid duplicate handlers) + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + fmt='[%(asctime)s.%(msecs)03d] [%(levelname)s] [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # Set log level + if level is None: + level = logging.DEBUG if CONFIG_DEBUG_MODE else logging.INFO + logger.setLevel(level) + + return logger \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimFrontend/extension_device.cpp deleted file mode 100644 index 1a02bfe3..00000000 --- a/PyTorchSimFrontend/extension_device.cpp +++ /dev/null @@ -1,388 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -static uint64_t op_counter = 0; -static uint64_t last_saved_value = 0; - -// register guard -namespace at { -namespace detail { - -C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl); - -}} // namespace at::detail - -// basic dummy add function -at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { - op_counter += 1; - // Since this custom device is just for testing, not bothering to implement kernels. - return at::empty(self.sizes(), self.options()); -} - -// basic dummy mul function -at::Tensor custom_mul_Tensor(const at::Tensor & self, const at::Tensor & other) { - op_counter += 1; - // Since this custom device is just for testing, not bothering to implement kernels. - return at::empty(self.sizes(), self.options()); -} - -at::Tensor _reinterpret_tensor( - const at::Tensor& self, - c10::IntArrayRef size, - c10::IntArrayRef stride, - int64_t offset_increment) { - at::Tensor self_ = at::detail::make_tensor( - c10::Storage(self.storage()), self.key_set(), self.dtype()); - auto* self_tmp_ = self_.unsafeGetTensorImpl(); - self_tmp_->set_storage_offset(self.storage_offset() + offset_increment); - self_tmp_->set_sizes_and_strides(size, stride); - return self_; -} - -at::Tensor& zero_inplace_batching_rule(at::Tensor &self) { - op_counter += 1; - // Since this custom device is just for testing, not bothering to implement kernels. - return self; -} - -const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size, - std::optional optional_memory_format) { - at::TensorImpl* tensor_impl = self.unsafeGetTensorImpl(); - tensor_impl->set_sizes_contiguous(size); - const auto itemsize = tensor_impl->dtype().itemsize(); - const auto offset = tensor_impl->storage_offset(); - const auto storage_size = at::detail::computeStorageNbytesContiguous(size, itemsize, offset); - // Dummy device is using cpu allocator, so here just call cpu - // function maybe_resize_storage_cpu in aten/src/ATen/native/Resize.h - // to get a sufficient memory space. - at::native::maybe_resize_storage_cpu(tensor_impl, storage_size); - if (optional_memory_format.has_value()) { - auto memory_format = - optional_memory_format.value(); - TORCH_CHECK( - memory_format != at::MemoryFormat::Preserve, - "Unsupported memory format", - memory_format); - tensor_impl->empty_tensor_restride(memory_format); - } - return self; -} - -// basic dummy eq function: Only support CPU -at::Tensor custom_to_device( - const at::Tensor & self, - at::Device device, - at::ScalarType dtype, - bool non_blocking, - bool copy, - c10::optional memory_format) { - TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); - TORCH_CHECK(device.is_cpu() || device.type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); - // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. - TORCH_CHECK(self.scalar_type() == dtype); - TORCH_CHECK(self.is_contiguous()); - - op_counter += 1; - if (device != at::DeviceType::CPU) { - return at::empty(self.sizes(), self.options()); - } - - auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, false, memory_format); - memcpy(out.mutable_data_ptr(), self.mutable_data_ptr(), self.nbytes()); - // Since this custom device is just for testing, not bothering to implement kernels. - return out; -} - - -// A dummy allocator for our custom device, that secretly uses the CPU -struct DummyCustomAllocator final : at::Allocator { - DummyCustomAllocator() = default; - at::DataPtr allocate(size_t nbytes) const override { - void* data = c10::alloc_cpu(nbytes); - return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; - } - - static void ReportAndDelete(void* ptr) { - if (!ptr) { - return; - } - c10::free_cpu(ptr); - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } -}; - -// Register our dummy allocator -static DummyCustomAllocator global_custom_alloc; -REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc); - -at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) { - TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows dummy device."); - TORCH_CHECK(self.is_contiguous()); - // TORCH_CHECK(self.scalar_type() == c10::ScalarType::Float); - - op_counter += 1; - if (self.scalar_type() == c10::ScalarType::Float) { - auto _data = static_cast(self.mutable_data_ptr()); - for (size_t idx = 0; idx < self.numel(); idx++) { - _data[idx] = value.toFloat(); - } - return self; - } else if (self.scalar_type() == c10::ScalarType::Int) { - auto _data = static_cast(self.mutable_data_ptr()); - for (size_t idx = 0; idx < self.numel(); idx++) { - _data[idx] = value.toInt(); - } - return self; - } else if (self.scalar_type() == c10::ScalarType::Long) { - auto _data = static_cast(self.mutable_data_ptr()); - for (size_t idx = 0; idx < self.numel(); idx++) { - _data[idx] = value.toLong(); - } - return self; - } else { - TORCH_CHECK(false, "Unsupported scalar type."); - } - - return self; -} - -at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) { - // TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1, - // "Only support dummy device."); - const auto& sizes_ = src.sizes(); - const auto& strides_ = src.strides(); - auto storage_offset_ = src.storage_offset(); - at::detail::check_size_nonnegative(sizes_); - - size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_, - src.element_size(), - storage_offset_); - - at::DataPtr data_ptr = - c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(), - [](void*){}, at::kCPU); - - c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), - /*allocator=*/&global_custom_alloc, /*resizeable=*/false}; - - constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU); - at::Tensor tensor = at::detail::make_tensor( - std::move(storage), cpu_ks, src.dtype()); - - c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); - tensor_impl->set_sizes_and_strides(sizes_, strides_); - tensor_impl->set_storage_offset(storage_offset_); - return tensor; -} - -// basic dummy copy_() function, so we can copy from the custom device to/from CPU -at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { - TORCH_CHECK( - self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, - "Dummy test only allows copy from cpu -> dummy device."); - TORCH_CHECK( - dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, - "Dummy test only allows copy from cpu -> dummy device."); - - // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. - TORCH_CHECK(self.sizes() == dst.sizes()); - - const bool same_dtype = (self.scalar_type() == dst.scalar_type()); - const bool both_contig = self.is_contiguous() && dst.is_contiguous(); - - // 1) fast path - if (same_dtype && both_contig) { - std::memcpy(dst.mutable_data_ptr(), - self.data_ptr(), - dst.storage().nbytes()); - return dst; - } - - // 2) slow path - at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self); - at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst); - if (!same_dtype) { - cpu_self = cpu_self.to(cpu_dst.scalar_type(), /*non_blocking=*/false, /*copy=*/true); - } - cpu_dst.copy_(cpu_self); - return dst; -} - -at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) { - return custom__copy_from(self, dst, false); -} - -at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) { - return at::native::abs_out(self, out); -} - -at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { - op_counter += 1; - constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); - auto dtype = c10::dtype_or_default(dtype_opt); - return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype); -} - -at::Tensor custom_empty(c10::IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt, c10::optional optional_memory_format) { - op_counter += 1; - - constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); - auto dtype = c10::dtype_or_default(dtype_opt); - return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, optional_memory_format); -} - -at::Tensor& custom_arange_start_out_impl( - const c10::Scalar& start, - const c10::Scalar& end, - const c10::Scalar& step, - at::Tensor& out) { - //const int64_t n = arange_len(start.toDouble(), end.toDouble(), step.toDouble()); - //at::native::resize_output(out, {n}); - return out; -} - -static at::Tensor custom_to_dtype_impl(const at::Tensor& self, - c10::ScalarType dtype, - bool non_blocking, bool copy, - c10::optional memory_format) { - return at::native::to(self, dtype, non_blocking, copy, memory_format); -} - -// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend. -// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key. -// Later in this file, we map a custom device to the PrivateUse1 device type, -// which allows user code that puts a tensor on your custom_device to eventually get plumbed -// into the kernels registered here. -// -// This macro registers your kernels to the PyTorch Dispatcher. -// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. -TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - m.impl("to.Device", &custom_to_device); - m.impl("to.dtype", &custom_to_dtype_impl); - m.impl("fill_.Scalar", &custom_fill__scalar); - m.impl("_copy_from", &custom__copy_from); - m.impl("_copy_from_and_resize", &custom__copy_from_and_resize); - m.impl("empty_strided", &custom_empty_strided); - m.impl("empty.memory_format", &custom_empty); - m.impl("as_strided", at::native::as_strided_tensorimpl); - m.impl("view", at::native::view); - m.impl("arange.start_out", &custom_arange_start_out_impl); -} - -TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) { - m.impl("to.dtype", &custom_to_dtype_impl); -} - -TORCH_LIBRARY_FRAGMENT(aten, m) { -m.def( - "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor", - torch::dispatch( - c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor), - {at::Tag::pt2_compliant_tag}); -} - -void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { - at::native::cpu_fallback(op, stack); -} - -TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - m.impl("add.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("add.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("abs.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("div.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow.Tensor_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("zero_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("neg.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sum.IntList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("eq.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("all.all_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_local_scalar_dense", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_log_softmax", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_log_softmax_backward_data", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("mse_loss.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nll_loss_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nll_loss_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_lerp_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_mul_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_addcmul_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_sqrt", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_div_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); -} - -// This basic implementation doesn't bother dealing with different device indices -// (e.g. custom_device:0 vs. custom_device:1). -// We could do that by letting the user pass in a device index in our exposed device function. -// Note that if you do that, you'll also need to register a device guard to core. -// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`. -c10::Device get_custom_device() { - return c10::Device(c10::DeviceType::PrivateUse1, 0); -} - -bool custom_op_called() { - bool called = false; - if (op_counter > last_saved_value) { - called = true; - last_saved_value = op_counter; - } - return called; -} - -class PrivateGeneratorImpl : public at::CPUGeneratorImpl { -public: - // Constructors - PrivateGeneratorImpl(c10::DeviceIndex device_index) { - device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); - key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); - } - ~PrivateGeneratorImpl() override = default; -}; - -// this is used to register generator -at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) { - return at::make_generator(device_index); -} - -void register_generator() { - REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1) -} - -// Here, we're exposing a custom device object that corresponds to our custom backend. -// We do this using pybind: exposing an "extension_name.custom_device()" function in python, -// that's implemented in C++. -// The implementation in this file maps directly to the `PrivateUse1` device type. -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("custom_device", &get_custom_device, "get custom device object"); - m.def("custom_op_called", &custom_op_called, "check if our custom function was called"); - m.def("register_generator", ®ister_generator, "register generator for custom device"); -} \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_op.py b/PyTorchSimFrontend/extension_op.py index 786e7398..e6351101 100644 --- a/PyTorchSimFrontend/extension_op.py +++ b/PyTorchSimFrontend/extension_op.py @@ -46,9 +46,6 @@ class MLIRExternKernelChoice(ExternKernelChoice): def call_name(self): - is_dryrun = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) - if is_dryrun: - return f"yield from sparse_mm_dummy_stonne_outer" return f"torch.ops.extension_op.{self.name}" custom_lib = torch.library.Library("extension_op", "DEF") @@ -275,10 +272,8 @@ def prepare_outer_product_matrix(a, b, out): def sparse_mm_stonne_outer(a, b, out): onnx_path, attribute_path, c_result_path = prepare_outer_product_matrix(a, b, out) - togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") - stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_single_c1_simple_noc.json' - TOGSim = TOGSimulator(togsim_path, stonne_config_path) - result_path = TOGSim.simulation(onnx_path) + stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_single_c1_simple_noc.yml' + result_path = TOGSimulator.run_standalone(onnx_path, config_path=stonne_config_path) TOGSimulator.get_result_from_file(result_path) # Load result data diff --git a/PyTorchSimFrontend/extension_utils.py b/PyTorchSimFrontend/extension_utils.py new file mode 100644 index 00000000..0418cacd --- /dev/null +++ b/PyTorchSimFrontend/extension_utils.py @@ -0,0 +1,26 @@ +import sympy +import torch + +""" +NOTE: Temporary File + +This file contains functions that were removed or changed in newer versions +of PyTorch. It is kept here only to temporarily enable compatibility while +upgrading to PyTorch 2.8 from PyTorch 2.2. + +These functions will eventually be integrated into the appropriate source files +or removed once no longer needed. + +This file is not intended to be permanent and should be deleted in the future. +""" + +def free_symbol_startswith(index: sympy.Expr, prefix: str): + return any(v.name.startswith(prefix) for v in index.free_symbols) + +def sympy_symbol(name: str) -> sympy.Symbol: + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return sympy.Symbol(name, integer=True, nonnegative=True) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index 988408ea..fe1f86a1 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -21,7 +21,7 @@ def hash_prefix(hash_value): return hash_value[1:12] def get_write_path(src_code): - return os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "outputs", hash_prefix(get_hash(src_code.strip()))) + return os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, hash_prefix(get_hash(src_code.strip()))) @dataclasses.dataclass class MLIRBenchmarkRequest(): @@ -49,6 +49,9 @@ def __init__( self.extra_args = extra_args #self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + def make_run_fn( self, input_tensors: torch.Tensor, output_tensors: torch.Tensor ) -> Callable[[], None]: @@ -58,20 +61,32 @@ def make_run_fn( # Check already cached result. write_path = get_write_path(self.source_code) key, _ = write(self.source_code, "mlir", specified_dir=write_path) - result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "outputs", hash_prefix(key), "togsim_result/0") - if os.path.exists(result_path): - result = TOGSimulator.get_result_from_file(result_path) - def cached_run_fn(*args, **kwargs): - return result - return cached_run_fn + result_dir = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, hash_prefix(key), "togsim_result") + + # Find the most recent .log file in the result directory + if os.path.exists(result_dir) and os.path.isdir(result_dir): + log_files = [f for f in os.listdir(result_dir) if f.endswith('.log')] + if log_files: + # Sort by modification time, get the most recent file + log_files_with_time = [ + (f, os.path.getmtime(os.path.join(result_dir, f))) + for f in log_files + ] + log_files_with_time.sort(key=lambda x: x[1], reverse=True) + latest_log_file = log_files_with_time[0][0] + result_path = os.path.join(result_dir, latest_log_file) + result = TOGSimulator.get_result_from_file(result_path) + def cached_run_fn(*args, **kwargs): + return result + return cached_run_fn # Run a candidate code run_method = custom_async_compile.mlir( self.source_code, vectorlane_size=self.extra_args["vector_lane"], - loop_size=None, spad_info=self.extra_args["spad_info"], + loop_size=self.extra_args["loop_size"], spad_info=self.extra_args["spad_info"], vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"], - origins="Unknown", silent_mode=True, - validate=self.extra_args['validate'], autotune=self.extra_args['autotune']) + origins=self.extra_args["origins"], silent_mode=True, + autotune=self.extra_args['autotune']) args = [ tensor @@ -84,5 +99,6 @@ def cached_run_fn(*args, **kwargs): *args, ) - def __str__(self) -> str: - return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" \ No newline at end of file + def update_workspace_size(self) -> None: + # FIXME: Not implemented yet. Checkout torch/_inductor/codegen/rocm/rocm_benchmark_request.py + return \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 178ea987..c5fd902f 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -26,20 +26,20 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ B }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { @@ -74,20 +74,20 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ B }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { {{kernel.load_input(indent_size=10)}} @@ -120,21 +120,21 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0=0 to {{ B }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} // Why not N,M? Currently, dma-fine-grained pass assume M->N order... {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} @@ -154,6 +154,9 @@ class MLIRBMMTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, @@ -163,8 +166,9 @@ def render(self, tile_info = None, **kwargs): X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) if tile_info is None: - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)[0] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node, precision_bytes)[0] else: TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info @@ -234,6 +238,7 @@ def render(self, else: Bias_idx = None + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -242,7 +247,7 @@ def render(self, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, - DATA_STYPE="f32", + DATA_STYPE=data_stype, X = X, W = W,Y = Y, Bias = Bias, X_idx = X_idx, W_idx = W_idx, @@ -316,6 +321,12 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if Bias is not None: + dtype_infos.append(("Bias", Bias.get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype BMM is not implemented yet ({dtype_desc})") W_tensor = empty_strided(W.layout.size, W.layout.stride) X_tensor = empty_strided(X.layout.size, X.layout.stride) @@ -340,10 +351,11 @@ def get_tile_candidates(self, prologue_nodes: Optional[List[IRNode]] = None, **kwargs): X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) + return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node, precision_bytes) - def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): - tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node, precision_bytes): + tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane diff --git a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py index dff6b0fd..7c842272 100644 --- a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py +++ b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py @@ -1,4 +1,5 @@ import os +import math import subprocess import shlex import re @@ -58,7 +59,11 @@ def load_arg(self): if self.is_in_arg(arg_attribute[0]): argv_idx = self.get_argv_idx() if arg_name not in self.load_args else self.load_args[arg_name] self.load_args[arg_name] = argv_idx - self.writeline(f'if(load_arg(c_{arg_name}, sizeof(c_{arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') + ctype = DTYPE_TO_C[arg_attribute[1]] + elem_count = arg_attribute[2] + size_expr = f'({elem_count}ULL * sizeof({ctype}))' + + self.writeline(f'if(load_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): self.writeline(f'return -1{self.ending}') self.writeline(self.closed_bracket) @@ -67,7 +72,10 @@ def dump_arg(self): for arg_name, arg_attribute in self.arg_attributes: if self.is_out_arg(arg_attribute[0]): argv_idx = self.get_argv_idx() if not self.is_inout_arg(arg_attribute[0]) else self.load_args[arg_name] - self.writeline(f'if(dump_arg(c_{arg_name}, sizeof(c_{arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') + ctype = DTYPE_TO_C[arg_attribute[1]] + elem_count = arg_attribute[2] + size_expr = f'({elem_count}ULL * sizeof({ctype}))' + self.writeline(f'if(dump_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): self.writeline(f'return -1{self.ending}') self.writeline(self.closed_bracket) @@ -84,29 +92,25 @@ def generate_kernel_declare(self): def generate_args_define(self): name_set = set() if self.validation: - self.writeline(f'int padding[0x100000]{self.ending}') # FIXME. For pooling operation... Some pooling layer use negative offset + self.writeline(f"int* padding = malloc(0x100000ULL * sizeof(int)){self.ending}") for arg_name, (_, arg_type, arg_size, arg_sizes, arg_stride) in self.arg_attributes: if not arg_name in name_set: - if self.validation: - self.writeline(f'{DTYPE_TO_C[arg_type]} c_{arg_name}[{arg_size}ULL]{self.ending}') + if torch.is_floating_point(torch.tensor([], dtype=arg_type)): + bits = torch.finfo(arg_type).bits + elif arg_type == torch.bool: + bits = 8 else: - if torch.is_floating_point(torch.tensor([], dtype=arg_type)): - bits = torch.finfo(arg_type).bits - elif arg_type == torch.bool: - bits = 8 - else: - bits = torch.iinfo(arg_type).bits - self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({arg_size * bits // 8}ULL){self.ending}') + bits = torch.iinfo(arg_type).bits + buffer_size = int(math.ceil(arg_size * bits // 8 / 64) * 64) * 2 # Round up to 64 bytes + Add some padding for safety + self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({buffer_size}ULL){self.ending}') name_set.add(arg_name) self.writeline(self.newline) def generate_main(self): - if self.validation: - self.generate_args_define() - self.writeline(f'{self.newline}int main(int argc, char *argv[]) {self.open_bracket}{self.newline}') with self.code.indent(): if self.validation: + self.generate_args_define() self.load_arg() self.writeline(self.newline) else: @@ -178,22 +182,18 @@ def add_extention(self, name, extension): def compile_wih_kernel(self, write_path, llvm_name, wrapper_name, binary_name, link_option=""): main_path = os.path.join(write_path, self.add_extention(wrapper_name, 'c')) main_obj_path = os.path.join(write_path, self.add_extention(wrapper_name, 'o')) - kernel_path = os.path.join(write_path, self.add_extention(llvm_name, 's')) kernel_obj_path = os.path.join(write_path, self.add_extention(llvm_name, 'o')) main_compile = f'riscv64-unknown-elf-gcc -march=rv64gcv -c {main_path} -o {main_obj_path}' - kernel_compile = f'clang -c --target="riscv64" -march=rv64gcv -O2 -nostdlib {kernel_path} -o {kernel_obj_path}' target = os.path.join(write_path, binary_name) link = f'riscv64-unknown-elf-gcc -march=rv64gcv {main_obj_path} {kernel_obj_path} -o {target} -lm {link_option}' main_compile_cmd = shlex.split(main_compile) - kernel_compile_cmd = shlex.split(kernel_compile) link_cmd = shlex.split(link) try: subprocess.check_call(main_compile_cmd) - subprocess.check_call(kernel_compile_cmd) subprocess.check_call(link_cmd) except subprocess.CalledProcessError as e: print("Command failed with exit code", e.returncode) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py new file mode 100644 index 00000000..7abdfee6 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -0,0 +1,367 @@ +from typing import List, Optional, Set +import math +import itertools + +import sympy +from torch._inductor.ir import IRNode + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=INPUT_NAMES, outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { +{%- for buffer_name, tile_desc in UNIQUE_BUFFER_TILE_DESCS.items() %} + {{ kernel.def_sram_buffer(buffer_name, tile_desc, indent_size=2) }} +{%- endfor %} + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %cat_block = 0 to 1 step 1 { +{%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ TILE_SIZES[d] }} { +{%- endfor %} +{%- for i in range(NUM_INPUTS) %} + // Input tensor{{ i }} + affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUTS[i].sizes[DIM] }} step {{ INPUTS[i].tile_size_dim }} { + %index{{ DIM }}_{{ i }} = affine.apply affine_map<(d0) -> (d0 + {{ INPUTS[i].cum_offset }})> (%index_local{{ DIM }}_{{ i }}) + %input_dram_offset_{{ i }} = affine.apply {{ INPUTS[i].offset_map }}({{ INPUTS[i].offset_vars }}) + %output_dram_offset_{{ i }} = affine.apply {{ OUTPUTS[i].offset_map }}({{ OUTPUTS[i].offset_vars }}) + {{ kernel.def_dma_op("MVIN", INPUTS[i].dram_name, [], INPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=INPUTS[i].dram_strides, dram_offset="input_dram_offset_" ~ i) }} + {{ kernel.def_dma_op("MVOUT", "Y", [], OUTPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=OUTPUTS[i].dram_strides, dram_offset="output_dram_offset_" ~ i) }} + } { inner_loop=true } +{%- endfor %} + +{%- for d in range(RANK-1) %} + } { outer_loop=true } +{%- endfor %} + } { outer_loop=true } + return +} +""" + + +class MLIRCatTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim): + super().__init__("kernel", input_nodes, layout) + self.dim = dim + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + input_nodes = self.input_nodes + y = self.output_node + dtype_infos = [("Y", y.get_dtype())] + [(f"X{i}", x.get_dtype()) for i, x in enumerate(input_nodes)] + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Cat is not implemented yet ({dtype_desc})") + precision_bytes = mlir_common.get_dtype_nbytes(y.get_dtype()) + num_inputs = len(input_nodes) + rank = len(y.get_size()) + + input_sizes = [x.get_size() for x in input_nodes] + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] + output_dim = [d for d, _ in enumerate(y.get_size()) if d != self.dim] + output_strides = y.get_layout().stride + + tile_sizes = list(tile_info) if tile_info is not None else [1] * len(output_sizes) + excluded_dims = self._compute_excluded_dims(tile_sizes) + + input_tile_sizes_dim = self._calculate_input_tile_sizes( + kernel, input_sizes, tile_sizes, num_inputs, rank, precision_bytes + ) + buffer_name_to_template_name, input_dram_names = self._build_buffer_mapping(input_nodes) + input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_dram_names, y, excluded_dims=excluded_dims + ) + (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) = self._build_dma_info( + input_nodes, input_sizes, output_strides, input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=excluded_dims + ) + + unique_buffer_tile_descs = { + buffer_name_to_template_name[name]: desc + for name, desc in unique_tile_descs.items() + } + names_str = ", ".join(input_dram_names + ["Y"]) + indent_size = 2 + (rank - 1) * 2 + 4 + + inputs_info = [ + dict( + dram_name = input_dram_names[i], + sizes = input_sizes[i], + tile_size_dim= input_tile_sizes_dim[i], + tile_desc = input_tile_descs[i], + offset_map = input_offset_maps[i], + offset_vars = input_offset_var_strs[i], + dram_strides = input_dram_strides[i], + cum_offset = cumulative_offsets[i], + ) + for i in range(num_inputs) + ] + outputs_info = [ + dict( + tile_desc = output_tile_descs[i], + offset_map = output_offset_maps[i], + offset_vars = output_offset_var_strs[i], + dram_strides = output_dram_strides[i], + ) + for i in range(num_inputs) + ] + + kernel.render_options = dict( + KERNEL_NAME = self.name, + kernel = kernel, + NUM_INPUTS = num_inputs, + NAMES_STR = names_str, + Y = y, + INPUT_NAMES = input_nodes, + RANK = rank, + DIM = self.dim, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + TILE_SIZES = tile_sizes, + UNIQUE_BUFFER_TILE_DESCS = unique_buffer_tile_descs, + INPUTS = inputs_info, + OUTPUTS = outputs_info, + INDENT_SIZE = indent_size, + input_reorder = self.input_reorder, + ) + + return self._template_from_string(TEMPLATE).render(**kernel.render_options) + + def get_tile_candidates( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ): + """Generate tile candidates for cat operation. Concat dimension always has tile size 1.""" + if template_buffer_node is not None: + self.output_node = template_buffer_node + + y = self.output_node + dtype_infos = [("Y", y.get_dtype())] + [(f"X{i}", x.get_dtype()) for i, x in enumerate(self.input_nodes)] + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Cat is not implemented yet ({dtype_desc})") + precision_bytes = mlir_common.get_dtype_nbytes(y.get_dtype()) + num_inputs = len(self.input_nodes) + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] + + if not output_sizes: + return [[1]] + + max_tile_total = kernel.spad_info["spad_size"] // ( + kernel.vector_lane * precision_bytes * 2 * num_inputs + ) + + dim_tile_candidates = [] + for dim_size in output_sizes: + max_tile = min(dim_size, max_tile_total) + candidates = set() + for mult in range(1, max_tile // kernel.vector_lane + 1): + t = mult * kernel.vector_lane + if t <= dim_size and dim_size % t == 0: + candidates.add(t) + if max_tile > 0: + for exp in range(int(math.log2(max_tile)) + 1): + t = 2 ** exp + if t <= dim_size and dim_size % t == 0: + candidates.add(t) + candidates.add(dim_size) # dim_size always divides itself + dim_tile_candidates.append(sorted(candidates)[:5]) + + tile_candidates = [ + list(combo) + for combo in itertools.product(*dim_tile_candidates) + if math.prod(combo) * (num_inputs + 1) * precision_bytes + <= kernel.spad_info["spad_size"] * kernel.vector_lane + ] + + if not tile_candidates: + tile_candidates = [[1] * len(output_sizes)] + + tile_candidates.sort(key=lambda x: -math.prod(x)) + return tile_candidates[:4] + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _compute_excluded_dims(self, tile_sizes: list) -> list: + """Return non-tiled dimension indices when rank exceeds the 4-dim limit.""" + max_tiled = 3 + if len(tile_sizes) <= max_tiled: + return [] + sorted_dims = sorted(enumerate(tile_sizes), key=lambda x: x[1], reverse=True) + excluded = [idx for idx, _ in sorted_dims[max_tiled:]] + for idx in excluded: + tile_sizes[idx] = 1 + return excluded + + def _calculate_input_tile_sizes(self, kernel, input_sizes, tile_sizes, num_inputs, rank, precision_bytes): + """Calculate tile sizes along the concat dimension for each input.""" + non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 + max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 + extra_concat = math.ceil(max_spad_per_input / (non_dim_tile_elements * precision_bytes)) - num_inputs + + input_tile_sizes_dim = [] + for i in range(num_inputs): + if extra_concat > 0 and non_dim_tile_elements > 0: + tile_dim = min(input_sizes[i][self.dim], extra_concat) + extra_concat -= tile_dim + else: + tile_dim = 1 + input_tile_sizes_dim.append(tile_dim) + return input_tile_sizes_dim + + def _build_buffer_mapping(self, input_nodes): + """Map actual buffer names to short template names (X0, X1, ...).""" + name_map = {} + template_names = [] + for x in input_nodes: + actual = x.get_name() + template = name_map.setdefault(actual, f"X{len(name_map)}") + template_names.append(template) + return name_map, template_names + + def _build_tile_descriptors( + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_buffer_names, output_node, excluded_dims=None + ): + """Build tile descriptors for every input (and its paired output).""" + if excluded_dims is None: + excluded_dims = set() + + def make_tile_desc(tile_sz, vector_lane, name, offset): + desc = mlir_common.MLIRMultiDimTile( + tile_sz, vector_lane, + vlane_split_axis=len(tile_sz) - 1, + vlane_stride=1 + ) + desc.set_tile_size(tile_sz) + desc.set_name(name) + desc.offset = offset + return desc + + output_offset = output_node.get_layout().offset + input_tile_descs, output_tile_descs, unique_tile_descs = [], [], {} + + for i, x in enumerate(input_nodes): + # Collect tile sizes for tiled dimensions only (skip excluded non-concat dims) + tile_sz = [] + tile_idx = 0 + for d in range(rank): + if d != self.dim: + if tile_idx not in excluded_dims: + tile_sz.append(tile_sizes[tile_idx]) + tile_idx += 1 + else: + tile_sz.append(input_tile_sizes_dim[i]) + + sram_name = f"{input_buffer_names[i].lower()}_cat_tile" + input_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, x.get_layout().offset)) + output_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, output_offset)) + + actual_name = x.get_name() + if actual_name not in unique_tile_descs: + unique_tile_descs[actual_name] = input_tile_descs[-1] + + return input_tile_descs, output_tile_descs, unique_tile_descs + + def _build_dma_info( + self, input_nodes, input_sizes, output_strides, + input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=None + ): + """Build per-input DRAM offset affine maps and tile strides. + + Three stride concepts are maintained: + + * layout_strides (internal) - raw DRAM buffer strides for every rank + dimension, used to compute the flat base-address affine map. + These reflect how the tensor is physically laid out in DRAM. + * dram_strides (returned, ``def_dma_op dram_stride=``) - stride in + DRAM per *tiled* dimension (excluded dims removed). The DMA engine + uses these to walk DRAM when loading/storing a tile. + * sram_strides (inside ``def_dma_op``, from tile_desc) - stride in + SRAM per tiled dimension. The DMA engine uses these to place data + into the SRAM tile buffer. + + Returns: + input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets + """ + if excluded_dims is None: + excluded_dims = set() + + def make_affine_map(idx_syms, strides, layout_offset): + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(len(idx_syms))) + return f"affine_map<({dim_str}) -> ({' + '.join(terms) if terms else '0'})>" + + cumulative_offsets = [0] + for i in range(num_inputs - 1): + cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim]) + + input_offset_maps, input_offset_var_strs, input_dram_strides = [], [], [] + output_offset_maps, output_offset_var_strs, output_dram_strides = [], [], [] + + for i, x in enumerate(input_nodes): + x_stride = x.get_layout().stride + in_syms, in_layout_strides, in_dram_strides = [], [], [] + out_syms, out_layout_strides, out_dram_strides = [], [], [] + tile_idx = 0 + + for d in range(rank): + if d != self.dim: + in_syms.append(sympy.Symbol(f"index{d}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{d}")) + out_layout_strides.append(int(output_strides[d])) + if tile_idx not in excluded_dims: + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + tile_idx += 1 + else: + in_syms.append(sympy.Symbol(f"index_local{self.dim}_{i}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{self.dim}_{i}")) + out_layout_strides.append(int(output_strides[d])) + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + + input_offset_maps.append(make_affine_map(in_syms, in_layout_strides, input_tile_descs[i].offset)) + input_offset_var_strs.append(", ".join(f"%{s}" for s in in_syms)) + input_dram_strides.append(in_dram_strides) + + output_offset_maps.append(make_affine_map(out_syms, out_layout_strides, output_tile_descs[i].offset)) + output_offset_var_strs.append(", ".join(f"%{s}" for s in out_syms)) + output_dram_strides.append(out_dram_strides) + + return (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 6650f429..58d6a70d 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1,17 +1,21 @@ import contextlib import sympy +import sys import re import os -import math from functools import reduce from operator import mul import torch +from typing import Optional from collections import defaultdict from concurrent.futures import ThreadPoolExecutor + +from PyTorchSimFrontend import extension_config from torch._dynamo.testing import rand_strided from torch._inductor.autotune_process import TensorMeta from torch._dynamo.utils import dynamo_timed from torch._inductor.codegen import cpp, wrapper, common, memory_planning +from torch._inductor.ir import GraphPartitionSignature from torch._inductor.virtualized import V, _ops as ops from torch._inductor.codecache import write_atomic from torch._inductor.utils import ( @@ -21,11 +25,16 @@ ) from torch.utils._sympy.functions import ModularIndexing, FloorDiv from PyTorchSimFrontend import extension_codecache -from PyTorchSimFrontend import extension_config from . import mlir_common from .mlir_common import LoopLevel, LoopNest +from .mlir_ops import ExtensionOverrides from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest +# Configure logger for mlir_codegen_backend module +logger = extension_config.setup_logger() + +from Simulator.simulator import ProgressBar + def reduction_init(reduction_type, dtype): if dtype in cpp.DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, the initial @@ -36,19 +45,9 @@ def reduction_init(reduction_type, dtype): if reduction_type == "prod": return float(1) if dtype.is_floating_point else int(1) if reduction_type in {"max", "argmax"}: - if dtype == torch.float32: - return f"0x{mlir_common.MLIR_INF['-inf']['f32']:x}" - elif dtype == torch.float64: - return f"0x{mlir_common.MLIR_INF['-inf']['f64']:x}" - else: - return "0.0" + return "-inf" if reduction_type in {"min", "argmin"}: - if dtype == torch.float32: - return f"0x{mlir_common.MLIR_INF['inf']['f32']:x}" - elif dtype == torch.float64: - return f"0x{mlir_common.MLIR_INF['inf']['f64']:x}" - else: - return "0.0" + return "inf" if reduction_type in {"welford_reduce"}: return f"0.0" raise AssertionError(reduction_type) @@ -63,26 +62,28 @@ def reduction_partial_combine_vec(reduction_type, vector_value, init_value): if reduction_type == "min": return ops.minimum(vector_value, init_value) if reduction_type == "any": - return ops.logical_and(vector_value, init_value) - raise AssertionError(reduction_type) - -def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape): - if reduction_type == "sum": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "prod": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "max": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "min": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" - if reduction_type == "any": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + return ops.logical_or(vector_value, init_value) raise AssertionError(reduction_type) -class ExtensionWrapperCodegen(wrapper.WrapperCodeGen): +class ExtensionWrapperCodegen(wrapper.PythonWrapperCodegen): def __init__(self): super().__init__() + @classmethod + def create( + cls, + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[wrapper.PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None and parent_wrapper is not None + return wrapper.SubgraphPythonWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + return cls() + def write_header(self): self.header.splice( f""" @@ -96,21 +97,28 @@ def write_header(self): from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile from torch._inductor.codegen.memory_planning import _align as align + from torch._inductor.async_compile import AsyncCompile from torch import device, empty, empty_strided from {extension_codecache.__name__} import CustomAsyncCompile - from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE + from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, setup_logger from Simulator.simulator import TOGSimulator from PyTorchSimFrontend.extension_op import sparse_mm_dummy_stonne_outer from torch._inductor.select_algorithm import extern_kernels + # Configure logger for generated wrapper code + _logger = setup_logger("PyTorchSimFrontend.mlir.generated_wrapper") + aten = torch.ops.aten inductor_ops = torch.ops.inductor assert_size_stride = torch._C._dynamo.guards.assert_size_stride + assert_alignment = torch._C._dynamo.guards.assert_alignment alloc_from_pool = torch.ops.inductor._alloc_from_pool - reinterpret_tensor = torch.ops.aten._reinterpret_tensor + reinterpret_tensor = torch.ops.inductor._reinterpret_tensor custom_async_compile = CustomAsyncCompile() + async_compile = AsyncCompile() os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__ + _logger.info(f'Wrapper Codegen Path = {{__file__}}') """ ) self.header.splice( @@ -142,6 +150,7 @@ def device2host_memcpy(buffer): ) def write_prefix(self): + self.write_async_compile_wait() self.prefix.splice( """ def call(args): @@ -154,7 +163,7 @@ def call(args): self.prefix.writeline(f"{lhs} = args") self.prefix.writeline("args.clear()") - self.codegen_inputs(self.prefix, V.graph.graph_inputs) + self.codegen_inputs() self.codegen_input_size_asserts() self.codegen_sram_plan_prefix() @@ -174,35 +183,60 @@ def codegen_sram_plan_postfix(self, outputs): continue self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") - @dynamo_timed + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + device = device or V.graph.get_current_device_or_throw() + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + return + def generate(self, is_inference): result = IndentedBuffer() - result.splice(self.header) + # result.splice(self.header) with contextlib.ExitStack() as stack: stack.enter_context(self.wrapper_call.indent()) self.memory_plan_reuse() - for line in self.lines: - # Add buffer plan hook for dealloc - if isinstance(line, memory_planning.DeallocFromPoolLine): - self.wrapper_call.writeline(f"sram_plan_postfix('{line.node.get_name()}', {line.node.get_name()})") - elif isinstance(line, str) and "del" in line: - name = line.split(" ")[1] - self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") - - if isinstance(line, wrapper.MemoryPlanningLine): - line.codegen(self.wrapper_call) - else: - self.wrapper_call.writeline(line) - # Add buffer plan hook for alloc - if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): - self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") + with self.set_writeline(self.wrapper_call.writeline): + for line in self.lines: + # Add buffer plan hook for dealloc + if isinstance(line, memory_planning.DeallocFromPoolLine): + self.wrapper_call.writeline(f"sram_plan_postfix('{line.node.get_name()}', {line.node.get_name()})") + elif isinstance(line, str) and "del" in line: + name = line.split(" ")[1] + self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") + + if isinstance(line, wrapper.MemoryPlanningLine): + line.codegen(self.wrapper_call) + elif isinstance(line, wrapper.KernelCallLine): + self.wrapper_call.writeline(self.wrap_kernel_call(line.kernel_name, line.call_args)) + else: + if isinstance(line, wrapper.WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + # Add buffer plan hook for alloc + if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): + self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") output_refs = self.get_output_refs() self.codegen_sram_plan_postfix(output_refs) self.mark_output_type() self.generate_return(output_refs) - self.append_precomputed_sizes_to_prefix() + # self.append_precomputed_sizes_to_prefix() # FIXME: Need to replace append_precomputed_sizes_to_prefix() + result.splice(self.header) + self.finalize_prefix() result.splice(self.prefix) @@ -211,679 +245,13 @@ def generate(self, is_inference): self.generate_end(result) self.add_benchmark_harness(result) - return result.getvaluewithlinemap() + return ( + result.getvaluewithlinemap(), + self.kernel_declarations.getvaluewithlinemap(), + ) def memory_plan(self): self.lines = memory_planning.MemoryPlanner(self).plan(self.lines) -class ExtensionOverrides(common.OpOverrides): - # Binary element wise operations - @staticmethod - def custom_cast(operand, target_type, *args, var_info=None, **kwargs): - dtype = var_info[operand][1] - if dtype == "index": - ret = ops.index_cast(operand, target_type, var_info=var_info) - else: - ret = ops.to_dtype(operand, target_type, var_info=var_info) - return ret, var_info[ret] - - @staticmethod - def binary_elementwise_common(operand1, operand2, var_info): - operand1.bounds = operand1.bounds.unknown() - operand2.bounds = operand2.bounds.unknown() - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - # Tile size check - if op_type1[0] != op_type2[0]: - # Try to broad cast - lhs_tile_size, lhs_dtype = op_type1 - rhs_tile_size, rhs_dtype = op_type2 - if lhs_tile_size > rhs_tile_size: - operand2 = ops.broadcast(operand2, operand1, var_info=var_info) - op_type2 = var_info[operand2] - elif lhs_tile_size < rhs_tile_size: - operand1 = ops.broadcast(operand1, operand2, var_info=var_info) - op_type1 = var_info[operand1] - - # Data type check - if op_type1[1] != op_type2[1]: - if op_type1[1] == "index" or op_type1 == "index": - if op_type1[1] == "index": - operand1 = ops.index_cast(operand1, op_type2[1], var_info) - op_type1 = var_info[operand1] - if op_type2[1] == "index": - operand2 = ops.index_cast(operand2, op_type1[1], var_info) - op_type2 = var_info[operand2] - elif op_type1[1][0] == "i" and op_type2[1][0] == "f": - operand1 = ops.to_dtype(operand1, op_type2[1], var_info) - op_type1 = var_info[operand1] - elif op_type1[1][0] == "f" and op_type2[1][0] == "i": - operand2 = ops.to_dtype(operand2, op_type1[1], var_info) - op_type2 = var_info[operand2] - elif op_type1[1][0] == op_type2[1][0]: - if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: - operand2 = ops.ext(operand2, op_type1[1]) - op_type2 = var_info[operand2] - elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: - operand1 = ops.ext(operand1, op_type2[1]) - op_type1 = var_info[operand1] - else: - raise NotImplementedError("Unsupported type converting") - - # Updated var info - tile_size = op_type1[0] - ret_type = op_type1[1] - return tile_size, ret_type, operand1, operand2 - - @staticmethod - def add(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.add{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def sub(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.sub{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def mul(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.mul{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def div(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.divf' - else: - opcode = f'arith.divui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def truediv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.divf' - else: - opcode = f'arith.divui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def modular(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - raise NotImplementedError("Not support remainder operation for floating point") - else: - opcode = f'arith.remui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def minimum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.minimumf' - else: - opcode = f'arith.minimumui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def maximum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - if ret_type[0] == "f": - opcode = f'arith.maximumf' - else: - opcode = f'arith.maximumui' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): - src_mlir_dtype = var_info[operand][1] - if src_mlir_dtype == "index": - operand = ops.index_cast(operand, "i64", var_info=var_info) - src_mlir_dtype = var_info[operand][1] - - tile_size = var_info[operand][0] - if isinstance(dst_mlir_dtype, torch.dtype): - dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype] - dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype] - src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype] - shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype - src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype - if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f": - return f"arith.fptoui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": - return f"arith.uitofp %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - if dst_mlir_dtype[0] == "i": - if dst_bits > src_bits: - return f"arith.extui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - elif dst_bits < src_bits: - return f"arith.trunc %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - return f"arith.maximumi %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] - elif dst_mlir_dtype[0] == "f": - if dst_bits > src_bits: - return f"arith.extf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - elif dst_bits < src_bits: - return f"arith.trunf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] - return f"arith.maximumf %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] - else: - raise NotImplementedError("Unsupported type for to_dtype ops") - - @staticmethod - def constant(value, src_type, *args, var_info=None, **kwargs): - if isinstance(src_type, torch.dtype): - src_type = mlir_common.DTYPE_TO_MLIR[src_type] - - if "inf" == str(value) or "-inf" == str(value) or "nan" == str(value): - value = f"0x{mlir_common.MLIR_INF[str(value)][src_type]:x}" - # if value represented by e notation, convert to float (ex 1e-3 -> 1.0e-3) - elif "e" in str(value): - value = format(float(value), ".20f") - elif src_type[0] == "f": - value = format(value, ".20f") - elif src_type[0] == "i": - value = int(value) - return f'arith.constant {value} : {src_type}', [1, src_type] - - @staticmethod - def alloc(size, src_type, *args, var_info=None, **kwargs): - return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] - - @staticmethod - def extractelement(operand, idx, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f"vector.extract %{operand}[{idx}]: {dtype} from {shape}", [1, dtype] - - # transcendental functions - @staticmethod - def exp(operand, *args, var_info=None, **kwargs): - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.exp(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.exp %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def exp2(operand, *args, var_info=None, **kwargs): - # Hands-on part: implement exp2 using math.exp2 - # var_info = {operand: [tile_size, dtype]} - # Ex) var_info[operand] = [8, "f32"] - - ln2 = math.log(2) - coeff = ops.constant(ln2, "f32") - operand = ops.mul(operand, coeff) - return ops.exp(operand), var_info[operand] - - @staticmethod - def erf(operand, *args, var_info=None, **kwargs): - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.erf(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.erf %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def tanh(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.tanh(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.tanh %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def sin(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.sin(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sin %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def cos(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - # Check scalar - op_type = var_info[operand] - if op_type[0] == 1: - val = ops.constant(0, op_type[1]) - var_info[val][0] = 4 - operand = ops.broadcast(operand, val) - val = ops.cos(operand) - result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.cos %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def sqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def rsqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def pow(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - # Type check & auto cast - if ret_type[0] != "f": - operand1, ret_type = ops.to_dtype(operand1, "f32", var_info=var_info) - var_info[operand1] = ret_type - - # Type check & auto cast - if ret_type[0] != "f": - operand2, ret_type = ops.to_dtype(operand2, "f32", var_info=var_info) - var_info[operand2] = ret_type - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f"math.pow{ret_type[0]} %{operand1}, %{operand2} : {shape}", [tile_size, ret_type] - - @staticmethod - def log(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.log %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def reciprocal(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - return ops.div(ops.constant(1.0, dtype), operand), [tile_size, dtype] - - @staticmethod - def ext(operand, dtype, *args, var_info=None, **kwargs): - op_type = var_info[operand] - shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" - target_type = f"vector<{op_type[0]}x{dtype}>" if op_type[0] > 1 else f"{dtype}" - if op_type[0] == "f": - opcode = f'arith.extf' - else: - opcode = f'arith.extui' - return f'{opcode} %{operand} : {shape} to {target_type}', [op_type[0], dtype] - - # Logical operations - @staticmethod - def neg(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype[0] != "f": - operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) - var_info[operand] = dtype - - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.negf %{operand} : {shape}', [tile_size, dtype] - - @staticmethod - def eq(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "oeq" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "eq" - else: - raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def ne(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "one" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sne" - else: - raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def lt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "olt" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "slt" - else: - raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def gt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "ogt" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sgt" - else: - raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def le(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "ole" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sle" - else: - raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def ge(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - if ret_type[0] == "f": - op_type = "arith.cmpf" - attribute = "oge" - elif ret_type[0] == "i": - op_type = "arith.cmpi" - attribute = "sge" - else: - raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] - - @staticmethod - def and_(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - - # Type check & auto cast - if op_type1[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand1] = dtype - - # Type check & auto cast - if op_type2[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand2] = dtype - - ret_type = op_type1[1] - tile_size = op_type1[0] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def or_(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - - # Type check & auto cast - if op_type1[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand1] = dtype - - # Type check & auto cast - if op_type2[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand2] = dtype - - ret_type = op_type1[1] - tile_size = op_type1[0] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - @staticmethod - def xor(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - - # Type check & auto cast - if op_type1[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand1] = dtype - - # Type check & auto cast - if op_type2[1][0] != "i": - operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) - var_info[operand2] = dtype - - ret_type = op_type1[1] - tile_size = op_type1[0] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - - - @staticmethod - def logical_and(operand1, operand2, *args, var_info=None, **kwargs): - op_type = var_info[operand1] - # Type check & auto cast - if op_type[1] != "i1": - raise NotImplementedError("Logical operation with not bool data type") - return ExtensionOverrides.and_(operand1, operand2, *args, var_info=var_info, **kwargs) - - @staticmethod - def logical_not(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - - ret_type = op_type[1] - tile_size = op_type[0] - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - const_one = ops.constant(0, ret_type) - const_one = ops.broadcast(const_one, operand, var_info=var_info) - ret = ops.eq(operand,const_one) - return ret, [tile_size, var_info[ret]] - - @staticmethod - def logical_or(operand1, operand2, *args, var_info=None, **kwargs): - op_type = var_info[operand1] - # Type check & auto cast - if op_type[1] != "i1": - raise NotImplementedError("Logical operation with not bool data type") - return ExtensionOverrides.or_(operand1, operand2, *args, var_info=var_info, **kwargs) - - @staticmethod - def logical_xor(operand1, operand2, *args, var_info=None, **kwargs): - op_type = var_info[operand1] - # Type check & auto cast - if op_type[1] != "i1": - raise NotImplementedError("Logical operation with not bool data type") - return ExtensionOverrides.xor(operand1, operand2, *args, var_info=var_info, **kwargs) - - @staticmethod - def relu(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - ret_type = "f32" - return ops.maximum(operand, ops.constant(0.0, "f32")), [tile_size, ret_type] - - @staticmethod - def sigmoid(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - ret_type = "f32" - one = ops.constant(1, "f32") - return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, ret_type] - - # Special operaitons - @staticmethod - def where(condition, operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - cond_type = var_info[condition] - if cond_type[0] < tile_size: - condition = ops.broadcast(condition, operand1, var_info=var_info) - elif cond_type[0] > tile_size: - operand1 = ops.broadcast(operand1, condition, var_info=var_info) - operand2 = ops.broadcast(operand2, condition, var_info=var_info) - tile_size, ret_type = var_info[operand1] - - shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" - return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape} {shape}", [tile_size, ret_type] - - - @staticmethod - def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): - result = body() - val = ops.constant(other, dtype, *args, **kwargs) - result = ops.where(mask, result, val) - return result, var_info[result] - - @staticmethod - def index_cast(operand, target_type, *args, var_info=None, **kwrags): - op_type = var_info[operand] - src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] - des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type - return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] - - @staticmethod - def broadcast_unflat(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>"# if op_type1[0] > 1 else op_type1[1] - des_shape = f"vector<{op_type2[0]//op_type1[0]}x{op_type1[0]}x{op_type1[1]}>"# if op_type2[0] > 1 else op_type1[1] # Use tile size only - - expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" - return expand, [op_type2[0], op_type1[1]] - - @staticmethod - def broadcast(operand1, operand2, *args, var_info=None, **kwargs): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>" if op_type1[0] > 1 else op_type1[1] - des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" # if op_type2[0] > 1 else op_type1[1] # Use tile size only - - # Special case for length 2 vector. We used this vector to avoid scalar operations... - if op_type1[0] != 1 and op_type2[0] % op_type1[0] == 0: - unflat_operand = ops.broadcast_unflat(operand1, operand2) - unflat_shape = f"vector<{op_type2[0]//op_type1[0]}x{op_type1[0]}x{op_type1[1]}>" - expand = f"vector.shape_cast %{unflat_operand} : {unflat_shape} to {des_shape}" - elif op_type1[0] == 1: - expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" - else: - raise NotImplementedError("Not supporting broadcast type...") - return expand, [op_type2[0], op_type1[1]] RTYPE_TO_MLIR = { "sum": "add", @@ -918,7 +286,13 @@ def __init__(self, kernel_group, reason=None): self.gem5_header = IndentedBuffer() self.header.writeline("#include ") self.header.writeline("#include ") - self.header.writeline("void* __wrap_malloc(size_t size) { return sbrk(size); }") + self.header.writeline("#include ") + self.header.writeline("void* __wrap_malloc(size_t size) {") # Align to 512 bytes + self.header.writeline(" size_t aligned = (size + 511UL) & ~511UL;") + self.header.writeline(" void *p = sbrk(aligned);") + #self.header.writeline(' fprintf(stderr, "[SPIKE][__wrap_malloc] addr=%p size=%zu (req=%zu)\\n", p, aligned, size);') + self.header.writeline(" return p;") + self.header.writeline("}") self.header.writeline("void __wrap_free(void *ptr) { return; }") self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad") @@ -946,9 +320,12 @@ def __init__(self, kernel_group, reason=None): self.reduce_iterator = {} self.spad_buffer_dict = dict() self.base_vector_initialized = False + self.loop_size = None def reset(self, reason): + save = self.exit_stack, self._nested_context_depth self.__init__(self.kernel_group, reason=reason) + self.exit_stack, self._nested_context_depth = save # padding type 0: zero-padding 1: negative-padding(-inf) ... def get_padding_type(self): @@ -962,7 +339,7 @@ def get_padding_type(self): # return 1 return 0 - def convert_index(self, expr, buffer): + def convert_index(self, expr): if len(expr.free_symbols) != 1: raise NotImplementedError("Not supporting this view operation...!") @@ -971,6 +348,7 @@ def convert_index(self, expr, buffer): expr_str = str(expr) if isinstance(expr, ModularIndexing): + dim = list(expr.args[0].free_symbols)[0] replace_str = f"({expr.args[0]} floordiv {expr.args[1]}) mod {expr.args[2]}" expr_str = re.sub(r"ModularIndexing\([^)]*\)", replace_str, expr_str) elif "//" in expr_str: @@ -978,17 +356,82 @@ def convert_index(self, expr, buffer): else: raise NotImplementedError("What is this case?") - indices = [expr.args[0]] - args = ", ".join(map(str, indices)) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>") - args = ", ".join([f"%{i}" for i in indices]) - index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})") + first_arg = expr.args[0] + if len(first_arg.free_symbols) != 1: + raise NotImplementedError("What is this case?") + + # Create affine.apply operation + indices = [list(first_arg.free_symbols)[0]] + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(indices, expr_str) + index = ops.affine_apply(map_var, indices) return index - def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> common.CSEVariable: - if buffer is None: - buffer = self.applys + def _convert_sympy_to_mlir_expr(self, expr, sorted_args): + """ + Convert sympy expression to MLIR affine map expression by replacing index variables. + """ + indices = [] + + for arg in sorted_args: + if arg.is_Mul and arg.args[0].is_number: + target_arg = arg.args[1] + elif not arg.is_number: + target_arg = arg + else: + continue + new_arg = sympy.Symbol(str(self.convert_index(target_arg))) + expr = expr.replace(target_arg, new_arg) + indices.append(str(new_arg)) + + # Convert ModularIndexing and FloorDiv to sympy expressions + # ModularIndexing(x, y, z) means (x // y) % z -> Mod(FloorDiv(x, y), z) + # FloorDiv(x, y) means x // y -> will be converted to floordiv in string representation + # Use preorder_traversal to find all instances + replacements = {} + for sub in sympy.preorder_traversal(expr): + if isinstance(sub, ModularIndexing): + # Convert ModularIndexing to Mod(FloorDiv(...), ...) + if sub.args[1] != 1: + floor_div = FloorDiv(sub.args[0], sub.args[1]) + else: + floor_div = sub.args[0] + mod_expr = sympy.Mod(floor_div, sub.args[2]) + replacements[sub] = mod_expr + elif isinstance(sub, FloorDiv): + # Keep FloorDiv as is, will be handled in custom string conversion + # We need to mark it for special handling + pass + + # Apply replacements + for old_expr, new_expr in replacements.items(): + expr = expr.subs(old_expr, new_expr) + + # Custom string conversion for MLIR affine expressions + def mlir_str(expr): + """Convert sympy expression to MLIR affine expression string""" + if isinstance(expr, FloorDiv): + return f"({mlir_str(expr.args[0])} floordiv {mlir_str(expr.args[1])})" + elif isinstance(expr, sympy.Mod): + return f"({mlir_str(expr.args[0])} mod {mlir_str(expr.args[1])})" + elif isinstance(expr, sympy.Add): + terms = [mlir_str(term) for term in expr.args] + return " + ".join(terms) + elif isinstance(expr, sympy.Mul): + factors = [mlir_str(factor) for factor in expr.args] + return " * ".join(factors) + elif isinstance(expr, sympy.Symbol): + return str(expr) + elif expr.is_number: + return str(expr) + else: + # Fallback to string representation + return str(expr) + + expr_str = mlir_str(expr) + return expr_str, indices + def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> common.CSEVariable: # Constant case if expr.is_number and len(indirect_dims) == 0: return self.get_const_cse(int(expr)) @@ -1004,34 +447,19 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com # Sort index variable.. ex) (%index1, %index0) args_dict = {term: list(term.free_symbols)[0] for term in args if term.free_symbols} sorted_args = sorted(args_dict.keys(), key=lambda term: str(args_dict[term])) - indices = [] - for arg in sorted_args: - if arg.is_Mul and arg.args[0].is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) - expr = expr.replace(arg.args[1], new_arg) - indices.append(str(new_arg)) - elif not arg.is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) - expr = expr.replace(arg, new_arg) - indices.append(str(new_arg)) - # Extract index var + # Convert sympy expression to affine map expression + expr_str, indices = self._convert_sympy_to_mlir_expr(expr, sorted_args) indirect_args = [f"%{i}" for i in indirect_dims] - if len(indirect_args): - comments = "{indirect_access} " + comments # Add indirect access attribute - expr_str = str(expr) - if "//" in expr_str: - expr_str = expr_str.replace("//", " floordiv ") - args = ", ".join(map(str, indices)) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>") - args = ", ".join([f"%{i}" for i in indices]) - index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[{','.join(indirect_args)}] {comments}") + # Create affine.apply operation + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(indices, expr_str, symbol_names=indirect_dims) + + index = ops.affine_apply(map_var, indices, indirect_dims=indirect_args, comment=comments) return index - def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) -> common.CSEVariable: - if buffer is None: - buffer = self.applys - zero_var = self.get_const_cse(0) + def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSEVariable: + """ Need to override buffer and cse to use this function. """ expr_list = [arg for arg in expr_list] dim_list = [f"d{i}" for i in range(len(expr_list))] @@ -1046,11 +474,16 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) new_expr_list = [0] * len(expr_list) for idx, arg in enumerate(expr_list): if arg.is_Mul and arg.args[0].is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) + new_arg = sympy.Symbol(str(self.convert_index(arg.args[1]))) new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) indices.append(str(new_arg)) elif not arg.is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) + try: + new_arg = sympy.Symbol(str(self.convert_index(arg))) + #not implemented case + except NotImplementedError: + print(f"Not implemented case: {arg}") + raise NotImplementedError(f"Not implemented case: {arg}") new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) indices.append(str(new_arg)) else: @@ -1060,15 +493,14 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) indices.append(str(new_arg)) # Extract index var + # Create affine.apply operation expr_str = str(sum(new_expr_list) + offset) - args = ", ".join(map(str, dim_list)) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") - args = ", ".join([f"%{i}" for i in indices]) - index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[]") + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(dim_list, expr_str) + index = ops.affine_apply(map_var, indices) return index def load(self, name: str, index: sympy.Expr): - index = self.rename_indexing(index) index, comptute_depedency = self.convert_indirect_indexing(index) padding = self.get_padding_type() @@ -1095,46 +527,48 @@ def load(self, name: str, index: sympy.Expr): tile_numel_per_lane = local_tile_desc.get_numel_per_lane() tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) tile_stride = local_tile_desc.get_tile_stride() - # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() # Define scratch pad buffer sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # MVIN Encoding - attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}" + attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, int(padding)) code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, dram_shape, tile_shape, attribute) self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching if not comptute_depedency: - compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector load instruction - if compute_vec_size > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - - out = self.cse.generate(load_buffer, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) - self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] - return out + with self.override_buffer_cse(buffer=load_buffer): + out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape) else: + # FIXME. Any good idea? out = sram_var self.register_var_info(out, [compute_vec_size, mlir_dtype]) - self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] - return out + self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] + return out - def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): - index = self.rename_indexing(index) - dram_var = self.kernel_group.args.output(name) + def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs): dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + # Handle scatter store + if "tmp" in str(index): + # Convert the output buffer type to the inplace buffer + arg_name = V.graph.scheduler.mutation_real_name.get(name, name) + if arg_name not in self.kernel_group.args.inplace_buffers: + self.kernel_group.args.make_inplace(arg_name, arg_name) + + if mode == "atomic_add": + loaded_value = ops.load(name, index) + value = ops.add(loaded_value, value) + index, _ = self.convert_indirect_indexing(index) + dram_var = self.kernel_group.args.output(name) + # Prepare dma instruction local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) vlane_split_axis = local_tile_desc.vmap.vlane_split_axis @@ -1148,9 +582,6 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() require_store = True - if compute_vec_size < self.var_info[value][0]: - value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}") - self.register_var_info(value, [compute_vec_size, mlir_dtype]) if str(value) in self.spad_buffer_dict: # Todo. If tile_size is not same (i.e., view operation), we can't apply peephole optimization easily @@ -1161,23 +592,22 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector store instruction - store_size, operand_type = self.var_info[value] + _, operand_type = self.var_info[value] if mlir_dtype != operand_type: - value = ops.custom_cast(value, mlir_dtype, var_info=self.var_info) + value = ops.to_dtype(value, mlir_dtype) - if compute_vec_size > 1 and store_size > 1: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" - self.stores.writeline(common.DeferredLine(name, line)) # TODO: Should be changed to self.compute? + if compute_vec_size < self.var_info[value][0]: + with self.override_buffer_cse(buffer=self.stores): + value = ops.extract_strided_slice(value, compute_vec_size) + + with self.override_buffer_cse(buffer=self.stores): + ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=name) else: sram_var = self.spad_buffer_dict[str(value)][0] sram_index_var = self.spad_buffer_dict[str(value)][3] # Generate DMA instruction - attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, dram_shape, tile_shape, attribute) self.dma_stores.writeline(common.DeferredLine(name, code)) @@ -1206,10 +636,12 @@ def reduction(self, dtype, src_dtype, reduction_type, value): vec_len = self.kernel_group.tile_desc.get_compute_vec_size() reduced_shape = self.kernel_group.tile_desc.get_mlir_vshape(type_name) + + # Prepare reduction init - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - init_vec = init if vec_len == 1 else self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") - self.register_var_info(init_vec, [vec_len, type_name]) + with self.override_buffer_cse(cse=self.const_cse, buffer=self.const_buffer): + init = self.get_const_cse(reduction_init(reduction_type, dtype), type_name) + init_vec = init if vec_len == 1 else ops.broadcast(init, vec_len) acc_var_list = [] iter_var_list = [] @@ -1239,192 +671,160 @@ def reduction(self, dtype, src_dtype, reduction_type, value): _, mask_var = self.get_mask() if mask_var is not None: value = ops.where(mask_var, value, init_vec) + result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) + result = ops.to_dtype(result, type_name) + self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iter_var_list[-1], reduced_shape) self.compute_body_loop.affine_yield[result] = reduced_shape - # Register affine yield var for reduction_depth, acc in enumerate(acc_var_list[1:]): self.affine_yield[acc] = reduced_shape, reduction_depth # Final reduction - acc = acc_var_list[0] # Set outermost acc var reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() + acc = acc_var_list[0] # Set outermost acc var + self.register_var_info(acc, [reduction_size, type_name]) assert(vec_len % reduction_size==0) - if vec_len > reduction_size: - init = self.const_cse.generate(self.reductions_suffix, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - if reduction_size == 1: - final_reduced_shape = f"{type_name}" - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, acc, init, axis=0, shape=reduced_shape, reduced_shape=final_reduced_shape)) - else: - final_reduced_shape = f"vector<{reduction_size}x{type_name}>" - init_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{init} : {type_name} to {final_reduced_shape}") - new_vshape= f"vector<{vec_len//reduction_size}x{reduction_size}x{type_name}>" - value = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{acc} : {reduced_shape} to {new_vshape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, value, init_vec, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape)) - acc = out - - # reigster reduction output - var_info = [reduction_size, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(acc, var_info) + + # Prepare init value + init = self.get_const_cse(reduction_init(reduction_type, dtype), type_name) + if reduction_size != 1: + with self.override_buffer_cse(buffer=self.reductions_suffix): + init = ops.broadcast(init, reduction_size) + + # Final reduction codegen + with self.override_buffer_cse(buffer=self.reductions_suffix): + if vec_len > reduction_size: + acc = ops.multi_reduction(acc, init, vec_len, reduction_size, reduced_shape, reduction_type, type_name) return acc def store_reduction(self, name, index, value): - # Note: Change cse temporaily # Store reduction can't share cached value stored in cse, # since it is not innermost loop body. - tmp_cse = self.cse - tmp_apply_cse = self.apply_cse - self.cse = self.reduction_cse - self.apply_cse = self.reduction_cse - dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - index = self.rename_indexing(index) - - # Tile is always reuduced in inner loop - local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) - vlane_split_axis = local_tile_desc.vmap.vlane_split_axis - vlane_stride = local_tile_desc.vmap.vlane_stride - - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = local_tile_desc.get_tile_stride() - compute_vec_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() - if compute_vec_size == 1: - vshape = f"{mlir_dtype}" - else: - vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" - sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) - if self.welford_reduce_out is not None: - sum, sqr_sum, _ = self.welford_reduce_out - # mean - reduction_numel = reduce(mul, self.ranges[self.reduction_depth:], 1) - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(reduction_numel)} : f32") - if compute_vec_size > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{mlir_dtype}>") - else: - divider_vec = divider - mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {vshape}") - - # m2 = (E(X^2) - E(X)^2) * N - sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sqr_sum}, %{divider_vec} : {vshape}") - mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{mean}, %{mean} : {vshape}") - variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {vshape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {vshape}") - if self.current_node.node.origin_node: # FIXME: This is a temporary solution - value = mean - else: - value = m2 - - # Select src type - if compute_vec_size == 1: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}" - else: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" - self.reductions_suffix.writeline(common.DeferredLine(name, line)) - # MVOUT Encoding - # Generate DMA instruction - attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) - self.reductions_suffix.writeline(common.DeferredLine(name, code)) + with self.override_buffer_cse(cse=self.reduction_cse): + # Tile is always reuduced in inner loop + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) + vlane_split_axis = local_tile_desc.vmap.vlane_split_axis + vlane_stride = local_tile_desc.vmap.vlane_stride - # Restore origin cse - self.cse = tmp_cse - self.apply_cse = tmp_apply_cse + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = local_tile_desc.get_tile_stride() - def indirect_indexing(self, index_var, size, check=True): + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) + with self.override_buffer_cse(buffer=self.reductions_suffix): + if self.welford_reduce_out is not None: + # Calc var and mean + sum, sqr_sum, _ = self.welford_reduce_out + reduction_numel = reduce(mul, self.ranges[self.reduction_depth:], 1) + divider = self.get_const_cse(float(reduction_numel), "f32") + mean = ops.truediv(sum, divider) + sqr_mean = ops.truediv(sqr_sum, divider) + mean_sqr = ops.mul(mean, mean) + variance = ops.sub(sqr_mean, mean_sqr) + m2 = ops.mul(variance, divider) + if self.current_node.node.origin_node: # FIXME: This is a temporary solution + value = mean + else: + value = m2 + # Store value to scratch pad + ops._store(value, sram_var, sram_index_var, tile_shape, buffer_name=name) + + # Generate DMA instruction + attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + self.reductions_suffix.writeline(common.DeferredLine(name, code)) + + def indirect_indexing(self, index_var, size, check=True, wrap_neg=True): return str(index_var) def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): # In case of index expr, dimension size should be divisible by tile size if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges): new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges) + prior_tile_size, prior_ranges = self.kernel_group.tile_desc.get_tile_size(), self.ranges self.kernel_group.tile_desc.set_tile_size(new_tile_size) self.reset("recompile") - raise mlir_common.RecompileSignal(f"Index access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})") + raise mlir_common.RecompileSignal(f"Index access (tile size {prior_tile_size} is not divisible by {prior_ranges})") - tile_size = tile_desc.get_tile_size_per_lane() + tile_size_per_lane = tile_desc.get_tile_size_per_lane() compute_vec_size = tile_desc.get_compute_vec_size() strides = tile_desc.get_tile_stride_per_lane() # Create vector index - compute_vec = self.cse.generate(self.compute, f"vector.broadcast %{self.compute_idx} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(compute_vec, [compute_vec_size, "index"]) + compute_vec = ops.broadcast(self.compute_idx, compute_vec_size) vector_index = ops.add(base_vector_index, compute_vec) # Create tile_dim index dim_list = [] - for idx in range(len(tile_size)): - div_coeff = self.get_const_cse(strides[idx], "index") - mod_coeff = self.get_const_cse(tile_size[idx], "index") - div_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{div_coeff} : index to vector<{compute_vec_size}xindex>") - mod_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{mod_coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(div_vec, [compute_vec_size, "index"]) - self.register_var_info(mod_vec, [compute_vec_size, "index"]) - dim = ops.modular(ops.div(vector_index, div_vec), mod_vec) - if idx == tile_desc.vmap.vlane_split_axis: # Need to add vector lane offset - offset = tile_desc.vmap.vlane_stride #* strides[idx] - outer_sz = tile_size[idx] // tile_desc.vmap.vlane_stride - - nr_vector_lane = self.get_const_cse(self.vector_lane, "index") - nr_vector_lane_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{nr_vector_lane} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(nr_vector_lane_vec, [compute_vec_size, "index"]) - + for idx in range(len(tile_size_per_lane)): + # Prepare initial values + offset = tile_desc.vmap.vlane_stride #* strides[idx] + outer_sz = tile_desc.get_tile_size()[idx] // tile_desc.vmap.vlane_stride + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + div_coeff = self.get_const_cse(strides[idx], "index") + mod_coeff = self.get_const_cse(tile_size_per_lane[idx], "index") vlane_stride_coeff = self.get_const_cse(tile_desc.vmap.vlane_stride, "index") vlane_outer_coeff = self.get_const_cse(outer_sz, "index") - vlane_stride_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_stride_coeff} : index to vector<{compute_vec_size}xindex>") - vlane_outer_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_outer_coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(vlane_stride_vec, [compute_vec_size, "index"]) - self.register_var_info(vlane_outer_vec, [compute_vec_size, "index"]) - stride_dim = ops.modular(dim, vlane_stride_vec) - outer_dim = ops.modular(ops.div(dim, vlane_stride_vec), vlane_outer_vec) + nr_vector_lane = self.get_const_cse(self.vector_lane, "index") + vlane_coeff = self.get_const_cse(0, "i64") - dim = ops.add(stride_dim, ops.mul(outer_dim, nr_vector_lane_vec)) + div_vec = ops.broadcast(div_coeff, compute_vec_size) + mod_vec = ops.broadcast(mod_coeff, compute_vec_size) + nr_vector_lane_vec = ops.broadcast(nr_vector_lane, compute_vec_size) + vlane_stride_vec = ops.broadcast(vlane_stride_coeff, compute_vec_size) + vlane_outer_vec = ops.broadcast(vlane_outer_coeff, compute_vec_size) # Prepare vlane offset (vidx) - vlane_coeff = self.get_const_cse(0, "i64") vlane_vec_size = 4 - vlane_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_coeff} : i64 to vector<{vlane_vec_size}xi64>") - vlane_offset = self.const_cse.generate(self.const_buffer, f"arith.addi %{vlane_vec}, %{vlane_vec} {{ vlane_offset={offset} }} : vector<{vlane_vec_size}xi64> // vlane offset") - self.register_var_info(vlane_offset, [vlane_vec_size, "i64"]) - vlane_offset = ops.index_cast(vlane_offset, "index") - self.register_var_info(vlane_offset, [vlane_vec_size, "index"]) + vlane_vec = ops.broadcast(vlane_coeff, vlane_vec_size) + dim = ops.remainder(ops.truncdiv(vector_index, div_vec), mod_vec) + if idx == tile_desc.vmap.vlane_split_axis: # Need to add vector lane offset + stride_dim = ops.remainder(dim, vlane_stride_vec) + outer_dim = ops.remainder(ops.truncdiv(dim, vlane_stride_vec), vlane_outer_vec) + dim = ops.add(stride_dim, ops.mul(outer_dim, nr_vector_lane_vec)) + + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + vlane_offset = ops.vlane_offset(vlane_vec, vlane_vec, attributes={"vlane_offset": offset}, comment="vlane offset") + if compute_vec_size < self.var_info[vlane_offset][0]: + vlane_offset = ops.extract_strided_slice(vlane_offset, compute_vec_size) + vlane_offset = ops.index_cast(vlane_offset, "index") dim = ops.add(dim, vlane_offset) dim_list.append(dim) indices = [str(i) for i in index.free_symbols] for idx in indices: i = int(idx[5:]) - index_vec = self.cse.generate(self.compute, f"vector.broadcast %{idx} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(index_vec, [compute_vec_size, "index"]) + idx = self.itervar_cses[idx] + index_vec = ops.broadcast(idx, compute_vec_size) offset = ops.add(index_vec, dim_list[i]) dim_list[i] = offset arg_lists = [] for arg in renamed_expression.args: if isinstance(arg, sympy.Integer): - offset = self.get_const_cse(int(arg)) - offset_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{offset} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(offset_vec, [compute_vec_size, "index"]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + offset = self.get_const_cse(int(arg), "index") + offset_vec = ops.broadcast(offset, compute_vec_size) arg_lists.append(offset_vec) elif isinstance(arg, sympy.Mul): if isinstance(arg.args[0], sympy.Integer) and isinstance(arg.args[1], sympy.Symbol): - coeff = self.get_const_cse(int(arg.args[0])) - coeff_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(coeff_vec, [compute_vec_size, "index"]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + coeff = self.get_const_cse(int(arg.args[0]), "index") + coeff_vec = ops.broadcast(coeff, compute_vec_size) result = ops.mul(dim_list[int(str(arg.args[1])[1:])], coeff_vec) arg_lists.append(result) elif isinstance(arg.args[1], sympy.Integer) and isinstance(arg.args[0], sympy.Symbol): - coeff = self.get_const_cse(int(arg.args[1])) - coeff_vec = self.cse.generate(self.compute, f"vector.broadcast %{coeff} : index to vector<{compute_vec_size}xindex>") - self.register_var_info(coeff_vec, [compute_vec_size, "index"]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + coeff = self.get_const_cse(int(arg.args[1]), "index") + coeff_vec = ops.broadcast(coeff, compute_vec_size) result = ops.mul(dim_list[int(str(arg.args[0])[1:])], coeff_vec) arg_lists.append(result) else: @@ -1458,7 +858,6 @@ def index_expr(self, index, dtype): tile_desc = base_tile_desc compute_vec_size = tile_desc.get_compute_vec_size() - tile_shape = f"memref<{compute_vec_size*self.vector_lane}xindex, 1>" vshape = f"vector<{compute_vec_size}xindex>" @@ -1474,18 +873,16 @@ def index_expr(self, index, dtype): # Initialize base vector if not self.base_vector_initialized: - init_iter = "iter" + init_iter = self.register_var_cse("init_iter", 1, "index") parallel_map = f"affine.parallel (%{init_iter}) = ({0}) to ({compute_vec_size}) {{ // Base vector initializer" self.spad_buffer.writeline(parallel_map) with self.spad_buffer.indent(): - self.spad_buffer.writeline(f"%init_vec = vector.broadcast %{init_iter} : index to vector<2xindex>") - self.spad_buffer.writeline(f"affine.vector_store %init_vec, %{sram_var}[%{init_iter}] : {tile_shape}, vector<2xindex>") + with self.override_buffer_cse(buffer=self.spad_buffer, cse=self.init_vec_cse): + init_vec = ops.broadcast(init_iter, 2) + ops._store(init_vec, sram_var, f"%{init_iter}", tile_shape) self.spad_buffer.writeline("}") self.base_vector_initialized = True - - line = f"affine.vector_load %{sram_var}[0] : {tile_shape}, {vshape}" - base_vector_index = self.cse.generate(self.compute, line) - self.register_var_info(base_vector_index, [compute_vec_size, "index"]) + base_vector_index = ops._load(compute_vec_size, "index", sram_var, "0", tile_shape) renamed_symbols = {symbol: "d"+str(symbol)[5:] for symbol in index.free_symbols} renamed_expression = index.subs(renamed_symbols) @@ -1575,15 +972,17 @@ def make_choices(self, nodes, kernel_name): # Try initial tile size self.reset(None) - src_code = super().codegen_nodes(nodes, kernel_name) + try: + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + except mlir_common.RecompileSignal: + continue current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) search_space.add(current_tile_sz) - if extension_config.CONFIG_DEBUG_MODE: - print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) + choices.append((bench_runner, src_code, meta_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) while prevent_infinite_loop < 10 and candidate_axes: for axis in list(candidate_axes): @@ -1598,14 +997,12 @@ def make_choices(self, nodes, kernel_name): # Try increase tile size for this axis try: self.kernel_group.tile_desc.scale_tile_dim(axis, prev_ranges[axis], 2) - except extension_codecache.TileSizeError as e: - # Failed to find proper tile size + self.reset(None) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + except (extension_codecache.TileSizeError, mlir_common.RecompileSignal): candidate_axes.remove(axis) self.reset(None) continue - - self.reset(None) - src_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) # FIXME. How to intergrate this constraint to tile system? @@ -1622,11 +1019,10 @@ def make_choices(self, nodes, kernel_name): # Add this choice search_space.add(current_tile_sz) - if extension_config.CONFIG_DEBUG_MODE: - print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride)) + choices.append((bench_runner, src_code, meta_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride)) prevent_infinite_loop += 1 self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold return choices @@ -1642,18 +1038,24 @@ def get_cycle(choice): return float("inf") return float("inf") # Exceeded maximum number of autotuning attempts choices = self.make_choices(*args) + if len(choices) == 0: # Can't autotune + return [None, None, None] + + # Get cycle time for each choice + # Show progress bar only when CONFIG_DEBUG_MODE is off + show_progress = not extension_config.CONFIG_DEBUG_MODE + with ProgressBar("[Auto-tune] Running benchmarks", silent_mode=not show_progress) if show_progress else contextlib.nullcontext(): + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(get_cycle, choices)) - if len(choices) == 0: # can't autotune - return [None, None] - with ThreadPoolExecutor(max_workers=8) as executor: - results = list(executor.map(get_cycle, choices)) - max_idx = results.index(min(results)) + min_idx = results.index(min(results)) if min(results) == float("inf"): raise RuntimeError("Failed to find optimal tile size...") - if extension_config.CONFIG_DEBUG_MODE: - self._log_autotune_result(choices[max_idx], results[max_idx]) - optimal_src_code, loop_size = choices[max_idx][1], choices[max_idx][-1] - return optimal_src_code, loop_size + + self._log_autotune_result(choices[min_idx], results[min_idx]) + + optimal_src_code, meta_code, loop_size = choices[min_idx][1], choices[min_idx][2], choices[min_idx][-1] + return optimal_src_code, meta_code, loop_size def run_bench(self, nodes, kernel_name, src_code): _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() @@ -1671,8 +1073,9 @@ def run_bench(self, nodes, kernel_name, src_code): "spad_info": self.spad_info, "vlen" : self.vlen, "arg_attributes" : arg_attributes, - "validate" : extension_config.pytorchsim_functional_mode, "autotune" : True, + "loop_size" : self.loop_size, + "origins" : {str(i) for node in nodes for i in node.node.origins}, }, source_code=src_code, ) @@ -1681,22 +1084,24 @@ def run_bench(self, nodes, kernel_name, src_code): return bmreq.make_run_fn(dummy_inputs, dummy_outputs) def _log_autotune_result(self, best_choice, best_cycle): - print( - f"[Auto-tune] Optimal tile size: {list(best_choice[2])}, " - f"vlane_stride: {best_choice[3]}, " + logger.debug( + f"Auto-tune: Optimal tile size: {list(best_choice[3])}, " + f"vlane_stride: {best_choice[4]}, " f"cycles: {best_cycle}" ) def codegen_nodes(self, nodes, kernel_name): - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) self._prepare_simulator_headers(src_code) if "autotune" in extension_config.codegen_mapping_strategy and extension_config.pytorchsim_timing_mode: - optimal_src_code = self.autotune(nodes, kernel_name)[0] + optimal_src_code, meta_code = self.autotune(nodes, kernel_name)[:2] if optimal_src_code is not None: - return optimal_src_code - return src_code + return optimal_src_code, meta_code + return src_code, meta_code def _prepare_simulator_headers(self, src_code): + from filelock import FileLock + write_path = extension_codecache.get_write_path(src_code) os.makedirs(write_path, exist_ok=True) @@ -1707,8 +1112,10 @@ def _prepare_simulator_headers(self, src_code): spad_section_end_symbol = ( f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({self.spad_info['spad_size']*self.vector_lane})));" ) - write_atomic(spike_write_path, self.header.getvalue() + spad_end_symbol + spad_section_end_symbol) - write_atomic(gem5_write_path, self.gem5_header.getvalue()) + lock = FileLock(extension_codecache.get_lock_path(write_path), timeout=extension_codecache.LOCK_TIMEOUT) + with lock: + write_atomic(spike_write_path, self.header.getvalue() + spad_end_symbol + spad_section_end_symbol) + write_atomic(gem5_write_path, self.gem5_header.getvalue()) def get_arg_info(self, name): arg_info = dict() @@ -1736,15 +1143,16 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe total_dims = [int(str(i)[5:]) for i in self.itervars] local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop - indirect_dims = [f"{i}" for i in index.free_symbols if "tmp" in str(i)] - for indirect_dim in indirect_dims: - index = index.replace(sympy.Symbol(indirect_dim), 0) + indirect_syms = [s for s in index.free_symbols if "tmp" in s.name] + index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True) + indirect_dims = [f"{i}" for i in indirect_syms] # Reduction can have two type of tile size if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): local_dims = total_dims # Brodatcast tile shape - index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims) + with self.override_buffer_cse(buffer=buffer, cse=self.apply_cse): + index_var = self.parse_indices(index, indirect_dims=indirect_dims, comments=f"// store_reduction={store_reduction}") if kg_tile_desc.vmap.vlane_split_axis in local_dims: local_vlane_split_axis = local_dims.index(kg_tile_desc.vmap.vlane_split_axis) @@ -1814,27 +1222,38 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe for constraint in sorted_constraints[1:]: index = index.replace(constraint.original_expr, 0) - # Calculate dram stride + # Calculate dram stride in local tile-dim order. + # This keeps dram/sram stride rank aligned with tile rank. + local_dim_to_axis = {dim: axis for axis, dim in enumerate(local_dims)} dram_stride = [0] * local_tile_desc.get_nr_dim() if index.is_Symbol: dim_idx = int(str(index)[5:]) - dram_stride[dim_idx] = 1 + if dim_idx in local_dim_to_axis: + dram_stride[local_dim_to_axis[dim_idx]] = 1 elif index.is_Number: pass else: + dram_dict = defaultdict(list) + implicit_dim_divisors = defaultdict(lambda: sys.maxsize) # Assume that div will have high priority than mod for arg in index.as_ordered_terms(): coeff, dim = arg.as_coeff_mul() if len(dim) == 0: continue real_dim = list(dim[0].free_symbols)[0] - dram_dict[str(real_dim)].append(coeff) + if dim[0].has(ModularIndexing): + if dim[0].args[1] < implicit_dim_divisors[str(real_dim)]: + implicit_dim_divisors[str(real_dim)] = dim[0].args[1] + dram_dict[str(real_dim)] = [coeff] + else: + dram_dict[str(real_dim)].append(coeff) + # Add missing dims if not added max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 for i in range(max_dim): target_dim = f"index{i}" - if target_dim not in str(index): + if sympy.Symbol(target_dim) not in index.free_symbols: dram_dict[target_dim] = [0] sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) @@ -1849,20 +1268,28 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe if not str(sub.args[0]).startswith("index"): continue dim_idx = int((str(sub.args[0])[5:])) + if dim_idx not in local_dim_to_axis: + continue + local_dim_idx = local_dim_to_axis[dim_idx] if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0: # In this case, need to recompile - original_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] - divisor = sub.args[1] + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + divisor = sub.args[1] * self.kernel_group.tile_desc.vmap.vlane_stride new_size = ((original_size + divisor - 1) // divisor) * divisor new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) new_tile_sizes[dim_idx] = new_size self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + # Can't use dim_idx as vlane_split_axis + if dim_idx == self.kernel_group.tile_desc.vmap.vlane_split_axis: + self.kernel_group.tile_desc.vmap.vlane_split_axis = (dim_idx + 1) % len(original_tile) + # Send recompile signal self.reset("recompile") raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}") - dim_divisor[dim_idx] = sub.args[1] + dim_divisor[local_dim_idx] = sub.args[1] # Update dram_stride, just insert 0 next to target dim offset = 0 @@ -1874,6 +1301,57 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split") offset = offset+1 + # Support ModularIndexing pattern + # This pattern can be used to broadcast ex) torch.cat([a,a]) + # ModularIndexing(x, y, z) means (x // y) % z + # tile_size must be: multiple of y (floorDiv divisor) and divisor of z (modular divisor) + if index.has(ModularIndexing): + for sub in sympy.preorder_traversal(index): + if isinstance(sub, ModularIndexing): + if not str(sub.args[0]).startswith("index"): + continue + dim_idx = int((str(list(sub.args[0].free_symbols)[0])[5:])) + floor_divisor = sub.args[1] # y: floorDiv divisor + mod_divisor = sub.args[2] # z: modular divisor + current_tile_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] + + # Check if tile_size is multiple of floorDiv divisor + if int(current_tile_size % floor_divisor) != 0: + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + divisor = floor_divisor * self.kernel_group.tile_desc.vmap.vlane_stride + new_size = ((original_size + divisor - 1) // divisor) * divisor + new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) + new_tile_sizes[dim_idx] = new_size + self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) + self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a multiple of floorDiv divisor {floor_divisor} in ModularIndexing") + + # Check if tile_size is a divisor of modular divisor + if int((mod_divisor * floor_divisor) % current_tile_size) != 0: + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + # Find the largest divisor of mod_divisor that is <= original_size + # and is a multiple of floor_divisor + new_size = original_size + while new_size > 0: + if mod_divisor % new_size == 0 and new_size % floor_divisor == 0: + break + new_size -= floor_divisor + + if new_size <= 0: + new_size = mod_divisor * floor_divisor + + new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) + new_tile_sizes[dim_idx] = new_size + self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) + self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a divisor of modular divisor {mod_divisor} in ModularIndexing") + # FIXME. It will be nice to modify node instead of this exception handling... if len(self.itervars) == 1 and self.reduction_depth == 0: # In case of reduction loop only case, we will add dummy loop so shift it once @@ -1957,15 +1435,19 @@ def get_scratchpad_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=N return sram_var, sram_index_var def get_const_cse(self, value, dtype="index") -> common.CSEVariable: + # Why not use ops.constant? Because there are some cases that can't use ops (e.g., def_dma_op) # Type convert - if dtype[0] == "f": + if value in ["inf", "-inf", "nan"]: + value = f"0x{mlir_common.MLIR_INF[value][dtype]:x}" + elif dtype[0] == "f": value = float(value) else: value = int(value) - - if value not in self.consts: - self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") - return self.consts[str(value)+dtype] + key = str(value)+dtype + if key not in self.consts: + self.consts[key] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") + self.register_var_info(self.consts[key], [1, dtype]) + return self.consts[key] def get_tag_cse(self, value=None, shape="memref<1xi32>"): if value is None: @@ -1979,16 +1461,16 @@ def get_mask(self): if self.compute_body_loop.size % self.compute_body_loop.step == 0: return None, None compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() - index_shape = f"vector<{self.compute_body_loop.step}xindex>" mask_shape = f"vector<{compute_vec_size}xi1>" - upper_bound = self.get_const_cse(self.compute_body_loop.size) - step_vec = self.const_cse.generate(self.const_buffer, f"vector.step : {index_shape}") + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + upper_bound = ops.constant(self.compute_body_loop.size, "index") + step_vec = ops.step(self.compute_body_loop.step, "index") - gap = self.mask_cse.generate(self.masks, f"arith.subi %{upper_bound}, %{self.compute_idx} : index") - gap_vec = self.mask_cse.generate(self.masks, f"vector.broadcast %{gap} : index to {index_shape}") - mask_var = self.mask_cse.generate(self.masks, f"arith.cmpi ult, %{step_vec}, %{gap_vec} : {index_shape}") - self.register_var_info(mask_var, [compute_vec_size, "i1"]) + with self.override_buffer_cse(buffer=self.masks, cse=self.mask_cse): + gap = ops.sub(upper_bound, self.compute_idx) + gap_vec = ops.broadcast(gap, self.compute_body_loop.step) + mask_var = ops.lt(step_vec, gap_vec) return mask_shape, mask_var def convert_indirect_indexing(self, index :sympy.Expr): @@ -2007,14 +1489,8 @@ def convert_indirect_indexing(self, index :sympy.Expr): indirect_dims.sort() first_dim = indirect_dims[0] spad_vars = dict() - old_compute, old_dma_lods, old_dma_stores = self.compute, self.dma_loads, self.dma_stores compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims]) - if compute_dependecy: - self.compute = old_dma_stores - target_dma_buffers = self.dma_stores - else: - self.compute = old_dma_lods - target_dma_buffers = self.dma_loads + target_dma_buffers = self.dma_stores if compute_dependecy else self.dma_loads # Load indirect operands for target_dim in indirect_dims: @@ -2028,62 +1504,48 @@ def convert_indirect_indexing(self, index :sympy.Expr): local_tile_desc = self.kernel_group.tile_desc tile_numel_per_lane = local_tile_desc.get_numel_per_lane() tile_shape = local_tile_desc.get_mlir_shape(var_info[1]) + tile_vec = local_tile_desc.get_compute_vec_size() vshape = f"vector<{var_info[0]}x{var_info[1]}>" sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim) self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] # Store the indirect index variable - opeartion = "affine.vector_store" + target_var = self.cse.varname_map[target_dim] compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) - line = f"{opeartion} %{target_dim}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - self.stores.writeline(line) + with self.override_buffer_cse(buffer=self.stores): + ops._store(target_var, sram_var, compute_index_var, tile_shape) mlir_dtype = vshape.split("x")[1][:-1] - vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute... - if tile_numel_per_lane > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape} // For indirect access" - else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape} // For indirect access" - out = self.cse.generate(target_dma_buffers, line) - self.register_var_info(out, [tile_numel_per_lane, mlir_dtype]) - spad_vars[target_dim] = out - - # Apply stride - for arg in index.args: - if "tmp" not in str(arg): - continue - if arg.is_Mul and arg.args[0].is_number: - coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1] - coeff = ops.constant(int(arg.args[0]), coeff_dtype) - spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff) - index = index.replace(arg, 0) - - # Sum - for dim, var in spad_vars.items(): - if dim == first_dim: - continue - spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) + with self.override_buffer_cse(buffer=target_dma_buffers): + out = ops._load(tile_numel_per_lane, mlir_dtype, sram_var, sram_index_var, tile_shape) + spad_vars[target_dim] = out + + with self.override_buffer_cse(buffer=target_dma_buffers): + # Apply stride + for arg in index.args: + if "tmp" not in str(arg): + continue + if arg.is_Mul and arg.args[0].is_number: + coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1] + coeff = self.get_const_cse(int(arg.args[0]), coeff_dtype) + spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff) + index = index.replace(arg, 0) + + # Sum + for dim, var in spad_vars.items(): + if dim == first_dim: + continue + spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) # Store index var sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim] mlir_dtype = vshape.split("x")[1][:-1] - vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute... - if tile_numel_per_lane > 1: - operation = "affine.vector_store" - line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.store" - line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}" - out = self.cse.generate(target_dma_buffers, line, assignment=False) + with self.override_buffer_cse(buffer=target_dma_buffers): + ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) # FIXME. Maybe require fine grain compute... # Conversion mlir_dtype = self.var_info[spad_vars[first_dim]][1] - line = f"affine.load %{sram_var}[{sram_index_var}] : {tile_shape}" - out = self.cse.generate(target_dma_buffers, line) - if mlir_dtype != "index": - line = f"arith.index_cast %{out} : {mlir_dtype} to {'index'}" - out = self.cse.generate(target_dma_buffers, line) - self.register_var_info(out, [1, "index", [1]]) - self.compute, self.dma_loads, self.dma_stores = old_compute, old_dma_lods, old_dma_stores + with self.override_buffer_cse(buffer=target_dma_buffers): + out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape) + if mlir_dtype != "index": + out = ops.index_cast(out, "index") return index + sympy.Symbol(str(out)), compute_dependecy diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 4d33eea4..5cde19eb 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -1,18 +1,22 @@ import dataclasses import math +import contextvars +from contextlib import contextmanager from dataclasses import dataclass -from typing import Dict -from typing import List +from typing import Dict, Iterable, List, Optional, Sequence, Union from collections import defaultdict from functools import reduce from operator import mul import torch + +from PyTorchSimFrontend import extension_config from torch._inductor.codegen import common from torch._inductor.codegen import cpp from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep -from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod +from torch._inductor.codegen.wrapper import KernelDefinitionLine +from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity import sympy import contextlib @@ -20,18 +24,20 @@ import sympy -import torch.fx from torch.utils._sympy.value_ranges import ValueRanges from torch._inductor.utils import ( - free_symbol_startswith, get_sympy_Expr_dtype, IndentedBuffer, sympy_subs, - sympy_symbol, unique, ) -from PyTorchSimFrontend import extension_config from PyTorchSimFrontend import extension_codecache + +from PyTorchSimFrontend.extension_utils import ( + free_symbol_startswith, + sympy_symbol +) + schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") DTYPE_TO_MLIR = { @@ -61,14 +67,14 @@ DTYPE_TO_C = { torch.float32: "float", torch.float64: "double", - torch.float16: "half", + torch.float16: "uint16_t", torch.int64: "int64_t", torch.int32: "int32_t", torch.int16: "int16_t", torch.int8: "int8_t", torch.uint8: "uint8_t", torch.bool: "uint8_t", - torch.bfloat16: "bfloat16", + torch.bfloat16: "uint16_t", } MLIR_TO_BIT = { @@ -84,6 +90,12 @@ "index": 64 } +def get_dtype_nbytes(dtype): + mlir_dtype = DTYPE_TO_MLIR.get(dtype) + if mlir_dtype is None or mlir_dtype not in MLIR_TO_BIT: + raise NotImplementedError(f"Unsupported dtype for precision calculation: {dtype}") + return MLIR_TO_BIT[mlir_dtype] // 8 + DTYPE_LOWP_FP = [ torch.bfloat16, torch.float16, @@ -91,19 +103,43 @@ MLIR_INF = { "inf" : { + "f16" : 0x7C00, "f32" : 0x7F800000, "f64" : 0x7FF0000000000000 }, "-inf" : { + "f16" : 0xFC00, "f32" : 0xFF800000, "f64" : 0xFFF0000000000000 }, "nan" : { + "f16" : 0x7C00, "f32" : 0x7FC00000, "f64" : 0x7FF8000000000000 } } +def format_dma_op_attributes( + dram_stride: Sequence, + sram_stride: Sequence, + padding: int = 0, + *, + subtile_size: Optional[Sequence] = None, + async_type: Optional[int] = None, +) -> str: + """Attribute dict for memref.dma_start; stride lists as bracketed integer lists.""" + parts = [ + f"dram_stride = {dram_stride}", + f"sram_stride = {sram_stride}", + f"padding = {int(padding)}", + ] + if subtile_size: + parts.append(f"subtile_size = {subtile_size}") + av = int(async_type) if async_type is not None else 1 + parts.append(f"async = {av} : i64") + return "{" + ", ".join(parts) + "}" + + class ParallelLoopBuffer(IndentedBuffer): def indent(self, offset=1, attribute="", suffix=""): @contextlib.contextmanager @@ -167,7 +203,11 @@ def get_mlir_shape(info): def mlir_argdefs(self, extra_node=dict()): buffer_types = {} for x in V.graph.buffers: - if not isinstance(x.layout, MultiOutputLayout): # FIXME: MultiOutputLayout should be handled + if isinstance(x.layout, MultiOutputLayout): + # MultiOutput kernel containers own concrete output nodes in `outputs`. + for out in getattr(x, "outputs", []): + buffer_types[out.get_name()] = [out.get_dtype(), out.get_numel(), out.get_size(), out.get_stride()] + else: buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel(), x.get_size(), x.get_stride()] for name, val in V.graph.graph_inputs.items(): if isinstance(val, sympy.Expr): @@ -244,17 +284,23 @@ def get_tile_stride_per_lane(self, tile_size: list[int], tile_stride: list[int]) return tile_stride def get_compute_vec_size(self, tile_size: list[int], reduction_numel: int, nr_rdim: int) -> int: - if self.forced_vec_size is not None: - return self.forced_vec_size - per_lane = self.get_numel_per_lane(tile_size) stride = self.vlane_stride if nr_rdim: val = per_lane // max(reduction_numel, 1) + result = val for mult in [8, 4, 2]: if per_lane >= val * mult: - return val * mult - return val + result = val * mult + break + if self.forced_vec_size is not None: + # Cap while keeping result divisible by val (= reduction_size). + # This preserves the assert(vec_len % reduction_size == 0) invariant. + capped = (min(result, self.forced_vec_size) // max(val, 1)) * max(val, 1) + result = max(capped, val) + return result + if self.forced_vec_size is not None: + return self.forced_vec_size for mult in [8, 4, 2]: if (per_lane // stride) >= mult: return stride * mult @@ -330,8 +376,8 @@ def _adjust_one(dim_size, tile_size): remain = candidate_tile_size[axis] % stride if remain: - candidate_tile_size[axis] += stride - remain - self.tile_constraint[axis].must_divide_dim = False + # #201: relax vlane_stride constraints + self.vmap.vlane_stride = 1 return candidate_tile_size def scale_tile_dim(self, axis, dim_sz, scale_factor=2): @@ -486,7 +532,7 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N self.name = "" self._tile_size = list(tile_size) self._tile_stride = None - self.tile_constraint = [TileConstraint(vlane_stride) for _ in tile_size] + self.tile_constraint = [TileConstraint(vlane_stride if idx == vlane_split_axis else 1) for idx, _ in enumerate(tile_size)] self.tile_axis_order = list(range(len(tile_size))) self.update_tile_stride() @@ -498,7 +544,7 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N vlane_stride=vlane_stride ) - self.implicit_dim_size = None + self.implicit_dim_size = {} self.nr_rdim = 0 self.offset = sympy.Integer(0) # Dram offset @@ -569,7 +615,6 @@ def __init__(self): # Default HW setting self.vector_lane = extension_config.vpu_num_lanes self.spad_info = extension_config.CONFIG_SPAD_INFO - self.precision = extension_config.CONFIG_PRECISION self.num_cores = extension_config.CONFIG_NUM_CORES self.vlen = extension_config.vpu_vector_length_bits @@ -588,6 +633,7 @@ def __init__(self, kernel_group, reason=None): self.ranges = None self.reduction_depth = None self.itervars = None + self.itervar_cses = None # Code buffer self.vector_compute = IndentedBuffer() self.reductions_suffix = IndentedBuffer() @@ -595,13 +641,19 @@ def __init__(self, kernel_group, reason=None): # MLIR SSA tracker self.var_info = {} # MLIR variable info self.buffer_types : dict = None # format: dtype, numel, size, stride - self.compute_idx = "compute_idx" + # Create compute idx + self.compute_idx = self.register_var_cse("compute_idx", 1, "index") self.compute_body_loop = LoopLevel(self.compute_idx, 1) self.prologue_compute_body_loop = LoopLevel(self.compute_idx, 1) self.recodegen = reason # spad overflow, tile size, vlane stride self.stop_autotune = False - def set_ranges(self, lengths, reduction_lengths): + instance_id = id(self) + self.target_buffer_override = contextvars.ContextVar(f"Handler_compute_override_{instance_id}", default=self.compute) + self.target_cse_override = contextvars.ContextVar(f"Handler_cse_override_{instance_id}", default=self.cse) + self._nested_context_depth = 0 + + def set_ranges(self, lengths, reduction_lengths, index_names=None): if self.call_ranges: assert self.call_ranges == tuple(lengths) + tuple( reduction_lengths @@ -610,7 +662,13 @@ def set_ranges(self, lengths, reduction_lengths): else: self.call_ranges = tuple(lengths) + tuple(reduction_lengths) self.ranges = [self.rename_indexing(x) for x in self.call_ranges] - self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] + if index_names is None: + self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] + else: + assert len(index_names) == len(self.ranges), f"Index names length mismatch: {len(index_names)} != {len(self.ranges)}" + self.itervars = [sympy.Symbol(str(n)) for n in index_names] + + self.itervar_cses = {str(index) : self.register_var_cse(str(index), 1, "index") for index in self.itervars} self.reduction_depth = len(lengths) return ( self.itervars[: self.reduction_depth], @@ -632,9 +690,14 @@ def store(self, name, index, value, mode=None): def reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError() - def indirect_indexing(self, index_var, size, check): + def indirect_indexing(self, index_var, size, check, wrap_neg): raise NotImplementedError() + def check_bounds(self, expr, size, lower, upper): + # MLIR backend currently relies on masked paths for out-of-bounds handling. + # Keep this hook as a no-op to satisfy Inductor's check_bounds callback. + return + def codegen_global_init(self): raise NotImplementedError() @@ -645,7 +708,7 @@ def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this - wrapper.generate_kernel_call(kernel_name, call_args, cuda=False) + wrapper.generate_kernel_call(kernel_name, call_args, triton=False) def is_modular_indexing(self, expr): return "ModularIndexing" in str(expr) @@ -679,7 +742,9 @@ def extract_dividers(self, implicit_ops): } new_index = operand.index.subs(subs_map) for arg in new_index.args: - if len(arg.free_symbols) != 1: + if arg.is_number: + continue + if len(arg.free_symbols) > 1: raise NotImplementedError("Not supporting this view operation...!") if arg.is_Mul and arg.args[0].is_number: arg = arg.args[1] @@ -709,13 +774,13 @@ def compute_tile_size(self, nodes, vars, reduction_vars): init_tile_desc.nr_rdim = len(reduction_vars) self.kernel_group.set_tile_info(init_tile_desc) - # Handle edge case - if len(self.ranges)==1 and self.ranges[0] == 1: # Scalar case 2 - self.kernel_group.tile_desc.vmap.vlane_stride = 1 - self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 - elif vlane_split_axis == -1: # Reduction only case - self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 - self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0] + # Handle edge case + if len(self.ranges)==1 and self.ranges[0] == 1: # Scalar case 2 + self.kernel_group.tile_desc.vmap.vlane_stride = 1 + self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 + elif vlane_split_axis == -1: # Reduction only case + self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 + self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0] # Handle implict dims. Input operand could be high dimension tensor. # Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173 @@ -752,10 +817,24 @@ def codegen_nodes(self, nodes, kernel_name): # Set node range info vars, reduction_vars = self.set_ranges(group, reduction_group) tile_desc = self.compute_tile_size(nodes, vars, reduction_vars) + _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() + safe_vec_size = self.get_safe_vec_size(tile_desc.get_compute_vec_size()) + # For pointwise (non-reduction) kernels, cap the MLIR vector size so that + # f16->f32 widening stays within LMUL<=4 (step and forced_vec_size must match). + # Reduction kernels are left unchanged: their accumulator/multi_reduction + # structure assumes compute_vec_size == step, so we must not split them here. + tile_desc.vmap.forced_vec_size = safe_vec_size + compute_vec = tile_desc.get_compute_vec_size() + # RVV requires vector lengths that produce integer power-of-2 LMUL values. + # Non-power-of-2 element counts (e.g. 24) cause LLVM WidenVectorResult crashes. + # Raise BEFORE the try/except so this propagates to make_choices (not retried). + if compute_vec > 1 and (compute_vec & (compute_vec - 1)) != 0: + raise RecompileSignal( + f"Non-power-of-2 compute_vec_size {compute_vec}: tile rejected (RVV requires power-of-2 LMUL)" + ) self.compute_body_loop.size = tile_desc.get_numel_per_lane() - self.compute_body_loop.step = tile_desc.get_compute_vec_size() + self.compute_body_loop.step = compute_vec try: - _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() with self as kernel: for node in nodes: node.run(vars, reduction_vars) @@ -769,8 +848,8 @@ def codegen_nodes(self, nodes, kernel_name): V.graph.removed_buffers |= self.removed_buffers # V.graph.inplaced_to_remove |= self.inplaced_to_remove src_code = self.codegen_kernel(kernel_name=kernel_name) - self.meta_kernel() - return src_code + meta_code = self.meta_kernel() + return src_code, meta_code def codegen_kernel(self, kernel_name): arg_defs, _, _, _ = self.kernel_group.args.mlir_argdefs() @@ -783,46 +862,19 @@ def codegen_kernel(self, kernel_name): code.splice(self.codegen_global_init()) code.writeline(f'func.func @{kernel_decl_name}({arg_defs})') with code.indent(): - for old, new in self.kernel_group.args.aliases(): - code.writeline(f"auto {old} = {new};") # Loop body part code.splice(self.codegen_loops()) return code.getvalue() def meta_kernel(self): - wrapper = V.graph.wrapper_code _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") - return arg_attributes + meta_code = arg_attributes + return meta_code def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] return constant_vector - def get_constant_vector2(self, expr): - # Case 0. symbol ex) index 0 - # Case 1. inner product form ex) 16 * index0 + 1 * index1 - # Case 2. Complicated form ex) 16 * index0 + 8 * (index//4) + (index % 4) - constant_vector = [] - if expr.is_symbol: - constant_vector.append(tuple([1, expr])) - return constant_vector - - for arg in expr.args: - if arg.is_symbol: - constant_vector.append(tuple([1,arg])) - continue - if len(arg.args) == 0: #TODO: check this - continue - if arg.args[0].is_number: - constant_vector.append(arg.args) - else: - constant_vector.append([1, arg]) - - return constant_vector - def find_node_by_name(self, name): if name in V.graph.graph_inputs: return V.graph.graph_inputs[name] @@ -837,6 +889,11 @@ def is_scalar(self, name): def roundup_vectorlane(self, size, amp=1): return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp + def register_var_cse(self, name, size, dtype): + var = self.create_cse_var(name, ValueRanges.unknown()) + self.register_var_info(var, [size, dtype]) + return var + def register_var_info(self, var, var_info): self.var_info[var] = var_info @@ -845,6 +902,18 @@ def rename_indexing(self, index) -> sympy.Expr: # and renames variables in index expressions to kernel arg names if isinstance(index, (list, tuple)): return [self.rename_indexing(x) for x in index] + + # FIXME. This is a temporary solution to remove Identity wrappers from index expression. + # Remove Identity wrappers from index expression + # Check if index itself is Identity + if isinstance(index, Identity): + index = index.args[0] if index.args else index + + # Replace Identity arguments with Identity.args[0] + Identity_args = [expr for expr in sympy.preorder_traversal(index) if isinstance(expr, Identity)] + for expr in Identity_args: + index = index.replace(expr, expr.args[0] if expr.args else expr) + index = V.graph.sizevars.simplify(index) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) replacements = { @@ -854,6 +923,27 @@ def rename_indexing(self, index) -> sympy.Expr: } return sympy_subs(index, replacements) + @contextmanager + def override_buffer_cse(self, *, buffer=None, cse=None): + buffer_override = self.target_buffer_override + cse_override = self.target_cse_override + buffer_token = cse_token = None + try: + # Store tokens for proper restoration in nested contexts + # contextvars.set() returns the previous value (token) which can be used for reset() + if buffer is not None: + buffer_token = buffer_override.set(buffer) + if cse is not None: + cse_token = cse_override.set(cse) + yield self + finally: + # Restore using tokens - contextvars automatically handles nested contexts + # Each level restores to its own previous value + if cse_token is not None: + cse_override.reset(cse_token) + if buffer_token is not None: + buffer_override.reset(buffer_token) + def __enter__(self): class CSEProxy: self.name = "CSEProxy" @@ -861,27 +951,38 @@ class CSEProxy: @staticmethod def __getattr__(name: str) -> Callable[..., common.CSEVariable]: # type: ignore[misc] def inner(*args, **kwargs): - code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info) - csevar = self.cse.generate( - self.compute, - code, - bounds=ValueRanges.unknown(), - assignment=(ret_info[0] is not None) - ) - if ret_info[0] is not None: - self.register_var_info(csevar, ret_info) - csevar.update_on_args(name, args, kwargs) + code, ret_info = getattr(parent_handler, name)(*args, **kwargs) + target_buffer = self.target_buffer_override.get() + target_cse = self.target_cse_override.get() + if isinstance(code, common.DeferredLine): + target_buffer.writeline(code) + return None + else: + csevar = target_cse.generate( + target_buffer, + code, + bounds=ValueRanges.unknown(), + assignment=(ret_info[0] is not None) + ) + if ret_info[0] is not None: + self.register_var_info(csevar, ret_info) + csevar.update_on_args(name, args, kwargs) return csevar return inner @staticmethod - def indirect_indexing(index_var, size, check=True): + def indirect_indexing(index_var, size, check=True, wrap_neg=True): # Skip CSE since this doesn't return an expression - return self.indirect_indexing(index_var, size, check) + return self.indirect_indexing(index_var, size, check, wrap_neg) + + @staticmethod + def check_bounds(index, size, lower, upper): + return self.check_bounds(index, size, lower, upper) @staticmethod def load(name: str, index: sympy.Expr): + index = self.rename_indexing(index) if name in self.cse.invalidated_stores: # A load from an invalidated store requires us to # keep the actual buffer around @@ -892,10 +993,10 @@ def load(name: str, index: sympy.Expr): if name in store_cache: return store_cache[name] key = name+str(index) - if key not in self.cse.cache: + if key not in self.cse._cache: result = self.load(name, index) - self.cse.cache[key] = result - return self.cse.cache[key] + self.cse._cache[key] = result + return self.cse._cache[key] @staticmethod def store(name, index, value, mode=None): @@ -903,9 +1004,10 @@ def store(name, index, value, mode=None): if mode is None: self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: + index = self.rename_indexing(index) return self.store(name, index, value, mode=mode) @staticmethod @@ -913,22 +1015,28 @@ def store_reduction(name, index, value): self.store_buffer_names.add(name) self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: + index = self.rename_indexing(index) return self.store_reduction(name, index, value) @staticmethod def reduction(dtype, src_dtype, reduction_type, value): return self.reduction(dtype, src_dtype, reduction_type, value) + @staticmethod + def check_bounds(index, size, lower, upper): + return self.check_bounds(index, size, lower, upper) + @staticmethod def _index_expr(tile_size, buffer, renamed_expression, index): return self._index_expr(tile_size, buffer, renamed_expression, index) @staticmethod def index_expr(index, dtype): + index = self.rename_indexing(index) return self.index_expr(index, dtype) @staticmethod @@ -957,13 +1065,56 @@ def bucketize( values, offsets_name, offsets_size, indexing_dtype, right ) - super().__enter__() - assert self.overrides - parent_handler = self.overrides(V.get_ops_handler()) - self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) - self.exit_stack.enter_context(V.set_kernel_handler(self)) + if self._nested_context_depth == 0: + self.exit_stack.__enter__() + assert self.overrides + parent_handler = self.overrides() + + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + self._nested_context_depth += 1 return self + def __exit__(self, exc_type, exc_val, exc_tb): + self._nested_context_depth -= 1 + if self._nested_context_depth == 0: + super().__exit__(exc_type, exc_val, exc_tb) + + def get_safe_vec_size(self, default_vec_size: int = 64) -> int: + """ + Cap forced vector size for low-precision paths so widening ops + (e.g., f16/bf16 -> f32) do not exceed RVV LMUL limits. + + Widening is legal up to source LMUL<=4 (destination LMUL<=8). + Using RVV relation LMUL = (SEW * VL) / VLEN, the safe source VL is: + VL <= 4 * VLEN / SEW + """ + + if not hasattr(self, "buffer_types") or not self.buffer_types: + return default_vec_size + + lowp_bits = [] + for info in self.buffer_types.values(): + dtype = info[0] if info else None + if dtype in DTYPE_LOWP_FP: + mlir_dtype = DTYPE_TO_MLIR[dtype] + lowp_bits.append(MLIR_TO_BIT[mlir_dtype]) + + if not lowp_bits: + return default_vec_size + + min_lowp_bits = min(lowp_bits) + # Constraint: Vector element count must be compatible across all types. + # VLEN=256: f16 (LMUL=2) and f32 (LMUL=4) both yield 32 elements. + # Note: Gem5 version restricts widening ops to LMUL < 8 for destination registers. + # Max LMUL set to 1 to ensure compatibility/safety. + + widen_safe_cap = self.vlen // min_lowp_bits + if widen_safe_cap <= 0: + return default_vec_size + + vec_size = min(default_vec_size, widen_safe_cap) + return vec_size @dataclasses.dataclass class LoopLevel: diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index a1a9d935..d577dbd8 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -2,7 +2,7 @@ import math from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs, get_dtype_nbytes from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode @@ -12,6 +12,9 @@ class MLIRConvCommonTemplate(MLIRTemplate): WRAPPER_TEMPLATE = None def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = False + self.support_reduction_fusion = False self.stride = kwargs["stride"] self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] @@ -37,7 +40,7 @@ def render(self, **kwargs): raise NotImplementedError() - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): raise NotImplementedError() def extract_info(self, kernel, template_buffer_node, epilogue_nodes): @@ -49,6 +52,13 @@ def extract_info(self, kernel, template_buffer_node, epilogue_nodes): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if Bias is not None: + dtype_infos.append(("Bias", Bias.get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Conv is not implemented yet ({dtype_desc})") + precision_bytes = get_dtype_nbytes(X.get_dtype()) if epilogue_nodes is not None: extra_node_rw = { @@ -66,7 +76,7 @@ def extract_info(self, kernel, template_buffer_node, epilogue_nodes): PADDING_W=self.padding[1] STRIDE_H=self.stride[0] STRIDE_W=self.stride[1] - return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W + return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W,precision_bytes def get_tile_candidates(self, kernel: MLIRTemplateKernel, @@ -74,19 +84,18 @@ def get_tile_candidates(self, epilogue_nodes: Optional[List[IRNode]] = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) - return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes) def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - eager_mode = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) options = dict( kernel=self.kernel, KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", + FUNC_NAME="wrapper_" + kernel_name, INPUT=X, WEIGHT=W, BIAS=Bias, @@ -94,11 +103,10 @@ def outer_func_render(self, kernel_name, input_args): PADDING_H=self.padding[0], PADDING_W=self.padding[1], VALIDATION_MODE=extension_config.pytorchsim_functional_mode, - TOGSIM_EAGER_MODE=eager_mode, input_reorder=self.input_reorder ) code = self._template_from_string(self.WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" + return code, "wrapper_" + kernel_name def get_arg_attributes(self): arg_attributes = [] @@ -115,6 +123,6 @@ def compute_stride(shape): return stride X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + arg_attributes.append([X.get_name(), [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) return arg_attributes diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index 0bf01421..8b8288a8 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -47,7 +47,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} @@ -59,7 +59,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { @@ -71,16 +71,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to 1 { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -104,7 +104,7 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: padded_shape = list(X.shape) padded_shape[2] += 2 * {{ PADDING_H }} padded_shape[3] += 2 * {{ PADDING_W }} - X_padding = torch.zeros(padded_shape, device=X.device) + X_padding = torch.zeros(padded_shape).to(device=X.device) X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X # Tanspose inputs @@ -120,9 +120,6 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Launch kernel {{ KERNEL_NAME }} - {%- if TOGSIM_EAGER_MODE %} - yield ({{KERNEL_NAME}}, ) - {%- endif %} """ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__(input_nodes, layout, input_reorder, **kwargs) @@ -134,12 +131,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -173,7 +170,7 @@ def render(self, Y_tile_desc.set_name("output_buffer") Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] - + # Extract Bias info Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) @@ -182,6 +179,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -223,7 +222,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -240,8 +239,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index 92b9a525..92efff66 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { @@ -58,7 +58,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -72,16 +72,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -105,7 +105,7 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: padded_shape = list(X.shape) padded_shape[2] += 2 * {{ PADDING_H }} padded_shape[3] += 2 * {{ PADDING_W }} - X_padding = torch.zeros(padded_shape, device=X.device) + X_padding = torch.zeros(padded_shape).to(device=X.device) X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X # Tanspose inputs @@ -121,9 +121,6 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Launch kernel {{ KERNEL_NAME }} - {%- if TOGSIM_EAGER_MODE %} - yield ({{KERNEL_NAME}}, ) - {%- endif %} """ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__(input_nodes, layout, input_reorder, **kwargs) @@ -135,12 +132,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -181,6 +178,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -222,7 +221,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -239,8 +238,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) # TODO: implement K_W for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index ab124852..dfd418d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} @@ -59,7 +59,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -72,16 +72,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -105,7 +105,7 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: padded_shape = list(X.shape) padded_shape[2] += 2 * {{ PADDING_H }} padded_shape[3] += 2 * {{ PADDING_W }} - X_padding = torch.zeros(padded_shape, device=X.device) + X_padding = torch.zeros(padded_shape).to(device=X.device) X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X # Tanspose inputs @@ -121,9 +121,6 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Launch kernel {{ KERNEL_NAME }} - {%- if TOGSIM_EAGER_MODE %} - yield ({{KERNEL_NAME}}, ) - {%- endif %} """ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__(input_nodes, layout, input_reorder, **kwargs) @@ -135,12 +132,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -182,6 +179,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -223,7 +222,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -240,8 +239,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) # TODO: implement K_W for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 66aa0a27..178ba7c6 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} @@ -60,7 +60,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -74,17 +74,17 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %tile_i_w = affine.apply #map_I_W(%tile_o_w, %tile_k_w) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_i_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -109,7 +109,7 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: padded_shape = list(X.shape) padded_shape[2] += 2 * {{ PADDING_H }} padded_shape[3] += 2 * {{ PADDING_W }} - X_padding = torch.zeros(padded_shape, device=X.device) + X_padding = torch.zeros(padded_shape).to(device=X.device) X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X # Tanspose inputs @@ -125,9 +125,6 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Launch kernel {{ KERNEL_NAME }} - {%- if TOGSIM_EAGER_MODE %} - yield ({{KERNEL_NAME}}, ) - {%- endif %} """ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__(input_nodes, layout, input_reorder, **kwargs) @@ -139,12 +136,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info TOG_latency = BATCH if TILE_M > BATCH else TILE_M @@ -186,6 +183,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -227,7 +226,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -244,8 +243,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_decomposition.py b/PyTorchSimFrontend/mlir/mlir_decomposition.py new file mode 100644 index 00000000..0f443cf8 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_decomposition.py @@ -0,0 +1,372 @@ +import math +import operator +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from torch._inductor.decomposition import register_decomposition + +aten = torch.ops.aten # only for @register_decomposition target + + +def _pair_2d(seq: Sequence[int]) -> Tuple[int, int]: + if len(seq) == 1: + v = int(seq[0]) + return v, v + return int(seq[0]), int(seq[1]) + + +def _int_eq(x, v: int) -> bool: + try: + return int(x) == v + except (TypeError, ValueError): + return False + + +def _can_rewrite_pointwise_conv_on_1x1_spatial_to_linear( + input: torch.Tensor, + weight: torch.Tensor, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + transposed: bool, + output_padding: Sequence[int], + groups: int, +) -> bool: + """ + Whether this ``aten.convolution`` is **exactly** ``F.linear`` on ``[N, C]`` (then reshaped + to ``[N, C_out, 1, 1]``): 1x1 kernel, spatial size 1x1, ``groups==1``, stride 1, no padding, + dilation 1 (typical SE line after global pool). + + If True, use ``_apply_pointwise_conv_on_1x1_spatial_as_linear``; if False, keep normal conv. + """ + if transposed or input.dim() != 4 or weight.dim() != 4: + return False + if groups != 1: + return False + if not ( + _int_eq(input.shape[2], 1) + and _int_eq(input.shape[3], 1) + and _int_eq(weight.shape[2], 1) + and _int_eq(weight.shape[3], 1) + ): + return False + + sh, sw = _pair_2d(stride) + ph, pw = _pair_2d(padding) + dh, dw = _pair_2d(dilation) + if sh != 1 or sw != 1 or ph != 0 or pw != 0 or dh != 1 or dw != 1: + return False + if len(output_padding) and any(not _int_eq(o, 0) for o in output_padding): + return False + + _, cin, _, _ = input.shape + _, cin_w, _, _ = weight.shape + try: + if int(cin_w) != int(cin): + return False + except (TypeError, ValueError): + return False + return True + + +def _apply_pointwise_conv_on_1x1_spatial_as_linear( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + """Same numerics as ``convolution``; call only when ``_can_rewrite_...`` is True.""" + n, cin, _, _ = input.shape + cout, _, _, _ = weight.shape + x = input.reshape(n, cin) + w = weight.reshape(cout, cin) + return F.linear(x, w, bias).reshape(n, cout, 1, 1) + + +def _group_conv_cin1_cout1( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], + groups: int, +) -> torch.Tensor: + """ + Grouped conv with ``Cin//groups == 1`` and ``Cout//groups == 1`` (input ``[N,G,H,W]``, weight ``[G,1,Kh,Kw]``). + + 1. Symmetric spatial padding on the input. + 2. For each kernel position ``(kh, kw)``, gather the output grid from the padded tensor and + multiply by ``weight[:, 0, kh, kw]`` (broadcast over ``N``), then sum over ``(kh, kw)``. + + Note + ---- + This is not a performance-optimized kernel: it is explicit gather–multiply–accumulate over + kernel elements. For competitive performance, add a dedicated template (or fused) kernel + instead of relying on this decomposition. + """ + n, c_in, _, _ = input.shape + # PyTorch layout: ``[Cout, Cin/groups, Kh, Kw]`` i.e. ``[G, 1, Kh, Kw]`` here. + c_out, cin_pg, kh, kw = weight.shape + g = groups + assert c_in == g and c_out == g and cin_pg == 1, (c_in, c_out, cin_pg, g) + + sh, sw = _pair_2d(stride) + ph, pw = _pair_2d(padding) + d_h, d_w = _pair_2d(dilation) + + # (left, right, top, bottom) for last two dims + x_pad = F.pad(input, (pw, pw, ph, ph)) + _, _, hp, wp = x_pad.shape + + h_out = (hp - d_h * (kh - 1) - 1) // sh + 1 + w_out = (wp - d_w * (kw - 1) - 1) // sw + 1 + + out = torch.zeros(n, g, h_out, w_out, dtype=input.dtype, device=input.device) + for ki in range(kh): + rows = torch.arange(h_out, device=input.device, dtype=torch.long) * sh + ki * d_h + for kj in range(kw): + cols = torch.arange(w_out, device=input.device, dtype=torch.long) * sw + kj * d_w + sub = x_pad[:, :, rows[:, None], cols[None, :]] + wgk = weight[:, 0, ki, kj].reshape(1, g, 1, 1) + out = out + sub * wgk + + if bias is not None: + out = out + bias.reshape(1, g, 1, 1) + return out + + +@register_decomposition(aten.convolution.default) +def decompose_convolution( + input: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + transposed: bool, + output_padding: Sequence[int], + groups: Union[int, torch.SymInt], +): + """ + 1. Pointwise 1x1 on spatial 1x1 (groups==1): rewrite to F.linear so backends + that struggle with tiny spatial convs (e.g. SE after AdaptiveAvgPool2d(1)) see + aten.mm / linear lowering instead. + + 2. Grouped conv when Cin//groups == Cout//groups == 1: _group_conv_cin1_cout1. + + Otherwise returns NotImplemented (Inductor uses the default aten.convolution). + + Note + ---- + The grouped path is not performance-optimized; it exists for correctness experiments. + """ + try: + gcount = operator.index(groups) + except (TypeError, ValueError): + return NotImplemented + + if _can_rewrite_pointwise_conv_on_1x1_spatial_to_linear( + input, + weight, + stride, + padding, + dilation, + transposed, + output_padding, + gcount, + ): + return _apply_pointwise_conv_on_1x1_spatial_as_linear(input, weight, bias) + + # groups==1, non-1x1 spatial: keep default aten.convolution (plain conv). + if gcount == 1: + return NotImplemented + + cin = input.shape[1] + cout = weight.shape[0] + cin_pg = cin // gcount + cout_pg = cout // gcount + supported = ( + not transposed + and cin % gcount == 0 + and cout % gcount == 0 + and cin_pg == 1 + and cout_pg == 1 + and weight.shape[1] == 1 + ) + if not supported: + raise NotImplementedError( + "PyTorchSim aten.convolution decomposition supports grouped conv only when " + "Cin//groups == 1 and Cout//groups == 1 (i.e. per-group Cin and Cout are 1). " + "For general group convolution, use the default kernel or a dedicated template kernel." + ) + return _group_conv_cin1_cout1( + input, + weight, + bias, + tuple(stride), + tuple(padding), + tuple(dilation), + gcount, + ) + +@register_decomposition(aten._native_multi_head_attention.default) +def decompose_native_multi_head_attention( + query, + key, + value, + embed_dim: int, + num_heads: int, + qkv_weight, + qkv_bias, + proj_weight, + proj_bias, + mask=None, + need_weights: bool = False, +): + """ + Decompose _native_multi_head_attention into scaled_dot_product_attention operations. + + Based on F.scaled_dot_product_attention and nn.MultiheadAttention implementation: + 1. QKV projection (if needed - but query/key/value may already be projected) + 2. Reshape to multi-head format + 3. Scaled dot product: Q @ K^T / sqrt(head_dim) + 4. Softmax + 5. Attention @ V + 6. Reshape back and output projection + """ + head_dim = embed_dim // num_heads + scale_factor = 1.0 / math.sqrt(head_dim) + + # Get input shapes - assuming [batch, seq_len, embed_dim] format + query_shape = query.shape + if len(query_shape) == 3: + # [batch, seq_len, embed_dim] format + batch_size = query_shape[0] + seq_len = query_shape[1] + elif len(query_shape) == 2: + # [seq_len, embed_dim] -> add batch dimension + batch_size = 1 + seq_len = query_shape[0] + query = query.unsqueeze(0) # [1, seq_len, embed_dim] + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + # Fallback: assume first dim is batch, second is seq_len + batch_size = query_shape[0] if len(query_shape) > 0 else 1 + seq_len = query_shape[1] if len(query_shape) > 1 else query_shape[0] + + # Step 1: QKV projection (if query/key/value are not already projected) + # In many cases, query/key/value are already projected, so we check if qkv_weight is used + # For now, assume they might need projection + # Note: In practice, _native_multi_head_attention often receives already projected inputs + + # Reshape for projection: [batch, seq_len, embed_dim] -> [batch*seq_len, embed_dim] + if len(query.shape) == 3: + query_flat = query.view(-1, embed_dim) + key_flat = key.view(-1, embed_dim) + value_flat = value.view(-1, embed_dim) + else: + query_flat = query + key_flat = key + value_flat = value + + # QKV projection using qkv_weight and qkv_bias + # Check if GQA (Grouped Query Attention) is used + # Standard MHA: qkv_weight shape = [3*embed_dim, embed_dim] + # GQA: qkv_weight shape = [embed_dim + 2*kv_embed_dim, embed_dim] where kv_embed_dim < embed_dim + qkv_weight_total = qkv_weight.shape[0] + + # Determine if GQA: if qkv_weight is not exactly 3*embed_dim, it might be GQA + if qkv_weight_total == 3 * embed_dim: + # Standard MHA: split equally + qkv_weight_q, qkv_weight_k, qkv_weight_v = torch.split(qkv_weight, embed_dim, dim=0) + if qkv_bias is not None: + qkv_bias_q, qkv_bias_k, qkv_bias_v = torch.split(qkv_bias, embed_dim, dim=0) + else: + qkv_bias_q = qkv_bias_k = qkv_bias_v = None + kv_embed_dim = embed_dim + kv_heads = num_heads + else: + # GQA: Q has embed_dim, K and V share the rest + # Assume Q = embed_dim, K = V = (qkv_weight_total - embed_dim) / 2 + q_dim = embed_dim + kv_dim = (qkv_weight_total - embed_dim) // 2 + qkv_weight_q = qkv_weight[:q_dim] + qkv_weight_k = qkv_weight[q_dim:q_dim + kv_dim] + qkv_weight_v = qkv_weight[q_dim + kv_dim:] + if qkv_bias is not None: + qkv_bias_q = qkv_bias[:q_dim] + qkv_bias_k = qkv_bias[q_dim:q_dim + kv_dim] + qkv_bias_v = qkv_bias[q_dim + kv_dim:] + else: + qkv_bias_q = qkv_bias_k = qkv_bias_v = None + kv_embed_dim = kv_dim + kv_heads = kv_embed_dim // head_dim # Number of KV heads + + # Project Q, K, V + q = torch.nn.functional.linear(query_flat, qkv_weight_q, qkv_bias_q) + k = torch.nn.functional.linear(key_flat, qkv_weight_k, qkv_bias_k) + v = torch.nn.functional.linear(value_flat, qkv_weight_v, qkv_bias_v) + + # Reshape back: [batch*seq_len, embed_dim] -> [batch, seq_len, embed_dim] + q = q.view(batch_size, seq_len, embed_dim) + k = k.view(batch_size, seq_len, kv_embed_dim) + v = v.view(batch_size, seq_len, kv_embed_dim) + + # Step 2: Reshape to multi-head format + # [batch, seq_len, embed_dim] -> [batch, seq_len, num_heads, head_dim] + q = q.view(batch_size, seq_len, num_heads, head_dim) + k = k.view(batch_size, seq_len, kv_heads, head_dim) + v = v.view(batch_size, seq_len, kv_heads, head_dim) + + # Transpose to [batch, num_heads, seq_len, head_dim] for bmm + q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) # [batch, kv_heads, seq_len, head_dim] + v = v.transpose(1, 2) # [batch, kv_heads, seq_len, head_dim] + + # GQA: If key/value have fewer heads, repeat them to match query heads + if kv_heads < num_heads: + repeat_factor = num_heads // kv_heads + k = k.repeat_interleave(repeat_factor, dim=1) # [batch, num_heads, seq_len, head_dim] + v = v.repeat_interleave(repeat_factor, dim=1) # [batch, num_heads, seq_len, head_dim] + + # Step 3: Scaled dot product attention + # Scale Q + q_scaled = q * scale_factor + + # Q @ K^T: [batch, num_heads, seq_len, head_dim] @ [batch, num_heads, head_dim, seq_len] + # -> [batch, num_heads, seq_len, seq_len] + k_transposed = k.transpose(-2, -1) # [batch, num_heads, head_dim, seq_len] + scores = torch.matmul(q_scaled, k_transposed) # [batch, num_heads, seq_len, seq_len] + + # Step 4: Apply mask if provided + if mask is not None: + if mask.dtype == torch.bool: + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + else: + attn_bias = mask + attn_bias + + # Step 5: Softmax along the last dimension (seq_len dimension) + attn_weights = F.softmax(scores, dim=-1) # [batch, num_heads, seq_len, seq_len] + + # Step 6: Attention @ V + # [batch, num_heads, seq_len, seq_len] @ [batch, num_heads, seq_len, head_dim] + # -> [batch, num_heads, seq_len, head_dim] + attn_output = torch.matmul(attn_weights, v) + + # Step 7: Reshape back to [batch, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2) # [batch, seq_len, num_heads, head_dim] + attn_output = attn_output.contiguous().view(batch_size, seq_len, embed_dim) + + # Step 8: Output projection + attn_output_flat = attn_output.view(-1, embed_dim) + output = torch.nn.functional.linear(attn_output_flat, proj_weight, proj_bias) + output = output.view(batch_size, seq_len, embed_dim) + + if need_weights: + # Return attention weights: [batch, num_heads, seq_len, seq_len] -> [batch, seq_len, seq_len] + attn_weights_mean = attn_weights.mean(dim=1) # Average over heads + return output, attn_weights_mean + else: + return (output, None) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index bbc63b45..9c61c3d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -27,14 +27,14 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}>{% endif %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { {%- if Bias %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { {% if prologue_nodes -%} @@ -77,16 +77,16 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { - %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {%- if Bias %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} @@ -105,6 +105,9 @@ class MLIRGemmTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, @@ -114,8 +117,9 @@ def render(self, tile_info = None, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) if tile_info is None: - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node)[0] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node, precision_bytes)[0] else: TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info @@ -154,7 +158,7 @@ def render(self, W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) W_tile_desc.set_name("W_buffer") W_tile_desc.offset = W.get_layout().offset - W_stride = W.get_layout().stride + W_stride = W.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0] W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]] vlane_split_axis = vlane_split_axis if nr_rdim==0 else 0 @@ -163,7 +167,7 @@ def render(self, Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) Y_tile_desc.set_name("Y_buffer") - Y_stride = Y.get_layout().stride + Y_stride = Y.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0] if nr_rdim == 0: Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] else: @@ -184,6 +188,8 @@ def render(self, else: Bias_idx = None + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -194,7 +200,7 @@ def render(self, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, - DATA_STYPE="f32", + DATA_STYPE=data_stype, X = X, W = W, Y = Y, Bias = Bias, X_idx = X_idx, @@ -269,7 +275,8 @@ def get_tile_candidates(self, prologue_nodes: Optional[List[IRNode]] = None, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) + return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node, precision_bytes) def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): if template_buffer_node is not None: @@ -277,6 +284,12 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): # Extract input arguments info X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if len(self.input_nodes) > 2: + dtype_infos.append(("Bias", self.input_nodes[2].get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype GEMM is not implemented yet ({dtype_desc})") X_tensor = empty_strided(X.layout.size, X.layout.stride) W_tensor = empty_strided(W.layout.size, W.layout.stride) if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: @@ -296,7 +309,7 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] return X,W,Y,M,N,K,n_epilogue_node,n_prologue_node,len(n_extra_read) - def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node, precision_bytes): data = {} gemm_shape = f"{M}_{N}_{K}" if "external" in extension_config.codegen_mapping_strategy: @@ -316,7 +329,7 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no else: # case 2: use heuristic mapping min_tile = (n_extra_node + n_prologue_node) == 0 - tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True) + tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True, precision_bytes=precision_bytes) # Edge case if (M == 0) or (N == 0) or (K == 0): diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..b717089f 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional, Sequence import torch @@ -15,10 +16,18 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate +from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate, MLIRStableSortTemplate +from PyTorchSimFrontend.mlir.mlir_sdpa_template import ( + MLIRFlashSDPATemplate, + flash_sdpa_args, + calculate_scale, +) from PyTorchSimFrontend import extension_config aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") +_orig_sort_values_stable_lowering = lowerings.get(aten.sort.values_stable) def tuned_mm(mat1, mat2, * ,layout=None): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) @@ -38,6 +47,29 @@ def tuned_bmm(mat1, mat2, *, layout=None): return mlir_template.generate().output_node() + +def tuned_flash_sdpa( + query : TensorBox, + key : TensorBox, + value : TensorBox, + attn_bias : Optional[TensorBox] = None, + dropout_p : float = 0.0, + is_causal : bool = False, + return_debug_mask : bool = False, + scale : Optional[float] = None, + enable_gqa : bool = False) -> tuple: + # _fused_sdp_choice in C++ already guarantees: + # L == S (prefill), Hq == H (non-GQA), dropout_p == 0.0 + # before routing here via SDPBackend::overrideable. + # Non-matching shapes fall back to SDPBackend::math in C++ and decompose + # into primitive ops (matmul/softmax) before reaching this lowering. + scale = calculate_scale(query, scale) + N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) + mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) + return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None) + + + def conv_layout( x: TensorBox, weight: TensorBox, @@ -181,11 +213,105 @@ def custom_unsafe_index(x, indices): x.realize() return index_impl(x, indices, check=False) + +def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: + with V.graph.fake_mode: + output = torch.ops.aten.cat( + [ir.ir_node_to_tensor(t, guard_shape=True) for t in tensors], + dim, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) + return ir.FixedLayout( + tensors[0].get_device(), + tensors[0].get_dtype(), + sizes, + stride, + ) + +def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): + if tensors and dim < 0: + dim += len(tensors[0].get_size()) + copy_default_lowering = lowerings.get(aten.copy_.default) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + new_tensors = [] + for t in tensors: + t.realize() + # If the tensor is backed by a view (ReinterpretView, PermuteView, etc.), + # materialise it into a fresh contiguous FixedLayout buffer so the cat + # kernel always receives plain, dense strides. + if isinstance(t.data, ir.BaseView): + sizes = list(t.get_size()) + strides = [math.prod(sizes[i + 1:]) for i in range(len(sizes))] + new_buf = empty_strided_lowering( + sizes, strides, dtype=t.get_dtype(), device=t.get_device() + ) + tt = copy_default_lowering(new_buf, t) + else: + tt = t + new_tensors.append(tt) + + layout = _cat_layout(new_tensors, dim) + mlir_template = MLIRCatTemplate(list(new_tensors), layout, dim=dim) + return mlir_template.generate().output_node() + +def custom_sort_default( + value: TensorBox, + dim: int = -1, + descending: bool = False, + stable: Optional[bool] = None, +): + if dim < 0: + dim += len(value.get_size()) + + value.realize() + + value_layout, index_layout = _sort_layouts(value, dim, descending) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + indices = empty_strided_lowering( + value.get_size(), + index_layout.stride, + dtype=torch.int64, + device=value.get_device(), + ) + stable_required = True if stable is None else stable + sort_template_cls = MLIRStableSortTemplate if stable_required else MLIRSortTemplate + mlir_template = sort_template_cls( + [value, indices], + value_layout, + dim=dim, + descending=descending, + stable=stable_required, + ) + sorted_values = mlir_template.generate(template_buffer_node=value).output_node() + return sorted_values, indices + + +def _sort_layouts(x: TensorBox, dim: int, descending: bool): + with V.graph.fake_mode: + v, i = torch.ops.aten.sort( + ir.ir_node_to_tensor(x, guard_shape=True), + dim, + descending, + ) + v_sizes = ir.convert_shape_to_inductor(v.size()) + v_stride = ir.convert_shape_to_inductor(v.stride()) + i_sizes = ir.convert_shape_to_inductor(i.size()) + i_stride = ir.convert_shape_to_inductor(i.stride()) + + value_layout = ir.FixedLayout(x.get_device(), x.get_dtype(), v_sizes, v_stride) + index_layout = ir.FixedLayout(x.get_device(), torch.int64, i_sizes, i_stride) + return value_layout, index_layout + lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) +lowerings.update({getattr(aten.cat, overload): custom_cat_default for overload in aten.cat.overloads()}) +lowerings.update({getattr(aten.sort, overload): custom_sort_default for overload in aten.sort.overloads()}) + if extension_config.CONFIG_USE_TIMING_POOLING: - lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file + lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template +lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()}) diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py new file mode 100644 index 00000000..218f60a9 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -0,0 +1,1293 @@ +import math +import torch +import warnings + +from torch._inductor.codegen import common +from torch._inductor.virtualized import V, _ops as ops +from . import mlir_common + +warnings.filterwarnings('ignore', message='undefined OpHandler\\..*, please add missing op schema') + +def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape): + if reduction_type == "sum": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "prod": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "max": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "min": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + if reduction_type == "any": + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + raise AssertionError(reduction_type) + +def format_mlir_op(op_str, shape, **kwargs): + """ + Format MLIR operation string with optional attributes and comment. + + Args: + op_str: Base operation string (e.g., "arith.addi %0, %1") + shape: Type shape string (e.g., "vector<4xi64>" or "i64") + **kwargs: May contain 'attributes' (dict or str) and 'comment' (str) + + Returns: + Formatted MLIR operation string + """ + result = op_str + attributes = kwargs.get('attributes', None) + comment = kwargs.get('comment', None) + + if attributes: + if isinstance(attributes, dict): + # Format: { key1=value1, key2=value2 } + attrs_str = ", ".join(f"{k}={v}" for k, v in attributes.items()) + result += f" {{ {attrs_str} }}" + elif isinstance(attributes, str): + # Direct string format + result += f" {{ {attributes} }}" + result += f" : {shape}" + if comment: + result += f" // {comment}" + return result + +class ExtensionOverrides(common.OpOverrides): + @staticmethod + def constant(value, src_type, *args, **kwargs): + if isinstance(src_type, torch.dtype): + src_type = mlir_common.DTYPE_TO_MLIR[src_type] + + str_val = str(value) + if "inf" == str_val or "-inf" == str_val or "nan" == str_val: + value = f"0x{mlir_common.MLIR_INF[str_val][src_type]:x}" + elif isinstance(value, bool): + value = 1 if value else 0 + if src_type[0] == "f": + value = format(float(value), ".20f") + # scientific notation check + elif "e" in str_val: + value = format(float(value), ".20f") + elif src_type[0] == "f": + value = format(float(value), ".20f") + elif src_type[0] == "i": + value = int(float(value)) + return format_mlir_op(f'arith.constant {value}', src_type, **kwargs), [1, src_type] + + @staticmethod + def broadcast(operand, target_size, *args, **kwargs): + src_size, dtype = V.kernel.var_info[operand] + + src_shape = f"vector<{src_size}x{dtype}>" if src_size > 1 else dtype + dst_shape = f"vector<{target_size}x{dtype}>" + + op_str = "" + # Special case for length 2 vector. We used this vector to avoid scalar operations... + if src_size > 1: + if target_size % src_size == 0: + unflat_operand = ops.broadcast_unflat(operand, target_size) + outer_dim = target_size // src_size + unflat_shape = f"vector<{outer_dim}x{src_size}x{dtype}>" + # Flatten back to 1D + op_str = f"vector.shape_cast %{unflat_operand}" + shape = f"{unflat_shape} to {dst_shape}" + else: + raise NotImplementedError( + f"Vector broadcast size mismatch: src={src_size} cannot broadcast to target={target_size}" + ) + elif src_size == 1: + op_str = f"vector.broadcast %{operand}" + shape = f"{src_shape} to {dst_shape}" + else: + raise ValueError(f"Invalid source size: {src_size}") + return format_mlir_op(op_str, shape, **kwargs), [target_size, dtype] + + @staticmethod + def broadcast_unflat(operand, target_size, *args, **kwargs): + src_size, dtype = V.kernel.var_info[operand] + + outer_dim = target_size // src_size + src_shape = f"vector<{src_size}x{dtype}>" + dst_shape = f"vector<{outer_dim}x{src_size}x{dtype}>" + + op_str = f"vector.broadcast %{operand}" + shape = f"{src_shape} to {dst_shape}" + return format_mlir_op(op_str, shape, **kwargs), [target_size, dtype] + + def load_seed(self, *args, **kwargs): + raise NotImplementedError + + def rand(self, *args, **kwargs): + raise NotImplementedError + + def randn(self, *args, **kwargs): + raise NotImplementedError + + def randint64(self, *args, **kwargs): + raise NotImplementedError + + # Special operaitons + @staticmethod + def masked(mask, body, other, *args, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): + result = body() + val = ops.constant(other, dtype, *args, **kwargs) + result = ops.where(mask, result, val) + return result, V.kernel.var_info[result] + + @staticmethod + def where(condition, operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + cond_type = V.kernel.var_info[condition] + operand_type = V.kernel.var_info[operand1] + condition = ops.to_bool(condition) + if cond_type[0] < tile_size: + condition = ops.broadcast(condition, tile_size) + elif cond_type[0] > tile_size: + operand1 = ops.broadcast(operand1, cond_type[0]) + operand2 = ops.broadcast(operand2, cond_type[0]) + tile_size, ret_type = V.kernel.var_info[operand1] + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + cond_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "" + + op_str = f"arith.select %{condition}, %{operand1}, %{operand2}" + shape = f"{cond_shape}, {shape}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): + # Extract source information + src_mlir_dtype = V.kernel.var_info[operand][1] + tile_size = V.kernel.var_info[operand][0] + + # Normalize destination type (Torch dtype -> MLIR string) + if isinstance(dst_mlir_dtype, torch.dtype): + dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype] + + if src_mlir_dtype == "index" and dst_mlir_dtype != "index": + operand = ops.index_cast(operand, "i64") + src_mlir_dtype = "i64" # Update explicitly + + if dst_mlir_dtype == "index": + # If source is already index, return as is; otherwise cast + if src_mlir_dtype == "index": + return operand, [tile_size, "index"] + return ops.index_cast(operand, "index"), [tile_size, "index"] + + # Early return if types are identical + if src_mlir_dtype == dst_mlir_dtype: + return operand, [tile_size, dst_mlir_dtype] + + dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype] + src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype] + shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype + src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype + src_type_char = src_mlir_dtype[0] # 'i' or 'f' + dst_type_char = dst_mlir_dtype[0] # 'i' or 'f'o + + op_str = "" + + # Case A: Integer -> Float + if src_type_char == "i" and dst_type_char == "f": + op_str = f"arith.uitofp %{operand} : {src_shape} to {shape}" + # Case B: Float -> Integer + elif src_type_char == "f" and dst_type_char == "i": + op_str = f"arith.fptosi %{operand} : {src_shape} to {shape}" + # Case C: Integer -> Integer (Extension / Truncation) + elif src_type_char == "i" and dst_type_char == "i": + if dst_bits > src_bits: + op_str = f"arith.extsi %{operand} : {src_shape} to {shape}" + elif dst_bits < src_bits: + # Use arith.trunci for integer truncation + op_str = f"arith.trunci %{operand} : {src_shape} to {shape}" + else: + return operand, [tile_size, dst_mlir_dtype] + # Case D: Float -> Float (Extension / Truncation) + elif src_type_char == "f" and dst_type_char == "f": + if dst_bits > src_bits: + op_str = f"arith.extf %{operand} : {src_shape} to {shape}" + elif dst_bits < src_bits: + # Corrected 'trunf' to 'truncf' + op_str = f"arith.truncf %{operand} : {src_shape} to {shape}" + else: + return operand, [tile_size, dst_mlir_dtype] + else: + raise NotImplementedError(f"Unsupported conversion: {src_mlir_dtype} -> {dst_mlir_dtype}") + + return op_str, [tile_size, dst_mlir_dtype] + + @staticmethod + def identity(operand, *args, **kwargs): + operand_info = V.kernel.var_info[operand] + return operand, operand_info + + @staticmethod + def to_dtype_bitcast(operand, dtype, *args, **kwargs): + tile_size, current_src_type = V.kernel.var_info[operand] + + if isinstance(dtype, torch.dtype): + dst_mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] + else: + dst_mlir_type = dtype + + src_bits = mlir_common.MLIR_TO_BIT[current_src_type] + dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_type] + + if src_bits != dst_bits: + raise ValueError( + f"Bitcast failed: Bit width mismatch. " + f"Src: {current_src_type}({src_bits}b) != Dst: {dst_mlir_type}({dst_bits}b)" + ) + + src_shape = f"vector<{tile_size}x{current_src_type}>" if tile_size > 1 else current_src_type + dst_shape = f"vector<{tile_size}x{dst_mlir_type}>" if tile_size > 1 else dst_mlir_type + + op_str = f"arith.bitcast %{operand}" + shape = f"{src_shape} to {dst_shape}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dst_mlir_type] + + # Binary element wise operations + @staticmethod + def binary_elementwise_common(operand1, operand2): + V.kernel.var_info = V.kernel.var_info + operand1.bounds = operand1.bounds.unknown() + operand2.bounds = operand2.bounds.unknown() + op_type1 = V.kernel.var_info[operand1] + op_type2 = V.kernel.var_info[operand2] + # Tile size check + if op_type1[0] != op_type2[0]: + # Try to broad cast + lhs_tile_size, lhs_dtype = op_type1 + rhs_tile_size, rhs_dtype = op_type2 + if lhs_tile_size > rhs_tile_size: + operand2 = ops.broadcast(operand2, lhs_tile_size) + op_type2 = V.kernel.var_info[operand2] + elif lhs_tile_size < rhs_tile_size: + operand1 = ops.broadcast(operand1, rhs_tile_size) + op_type1 = V.kernel.var_info[operand1] + + # Data type check + if op_type1[1] != op_type2[1]: + if op_type1[1] == "index" or op_type1 == "index": + if op_type1[1] == "index": + # index -> target type: 2-step casting if target is float + if op_type2[1][0] == "f": + operand1 = ops.index_cast(operand1, "i64") + operand1 = ops.to_dtype(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] + else: + # index -> integer: direct casting + operand1 = ops.index_cast(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] + if op_type2[1] == "index": + # index -> target type: 2-step casting if target is float + if op_type1[1][0] == "f": + operand2 = ops.index_cast(operand2, "i64") + operand2 = ops.to_dtype(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] + else: + # index -> integer: direct casting + operand2 = ops.index_cast(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] + elif op_type1[1][0] == "i" and op_type2[1][0] == "f": + operand1 = ops.to_dtype(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] + elif op_type1[1][0] == "f" and op_type2[1][0] == "i": + operand2 = ops.to_dtype(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] + elif op_type1[1][0] == op_type2[1][0]: + if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: + operand2 = ops.ext(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] + elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: + operand1 = ops.ext(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] + else: + raise NotImplementedError("Unsupported type converting") + + # Updated var info + tile_size = op_type1[0] + ret_type = op_type1[1] + return tile_size, ret_type, operand1, operand2 + + @staticmethod + def abs(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def exp(operand, *args, **kwargs): + # Check scalar + op_type = V.kernel.var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.exp(operand) + result = ops.extractelement(val, 0) + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.exp %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def exp2(operand, *args, **kwargs): + # Hands-on part: implement exp2 using math.exp2 + # V.kernel.var_info = {operand: [tile_size, dtype]} + # Ex) V.kernel.var_info[operand] = [8, "f32"] + + ln2 = math.log(2) + coeff = ops.constant(ln2, "f32") + operand = ops.mul(operand, coeff) + return ops.exp(operand), V.kernel.var_info[operand] + + @staticmethod + def expm1(operand, *args, **kwargs): + coeff = ops.constant(1.0, "f32") + operand = ops.exp(operand) + operand = ops.sub(operand, coeff) + return operand, V.kernel.var_info[operand] + + @staticmethod + def sqrt(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.sqrt %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def relu(operand, *args, **kwargs): + src_mlir_dtype = V.kernel.var_info[operand][1] + tile_size = V.kernel.var_info[operand][0] + return ops.maximum(operand, ops.constant(0, src_mlir_dtype)), [tile_size, src_mlir_dtype] + + @staticmethod + def minimum(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.minimumf' + else: + opcode = f'arith.minsi' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def maximum(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.maximumf' + else: + opcode = f'arith.maxsi' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def cos(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + + # Check scalar + op_type = V.kernel.var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.cos(operand) + result = ops.extractelement(val, 0) + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.cos %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def sin(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + + # Check scalar + op_type = V.kernel.var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.sin(operand) + result = ops.extractelement(val, 0) + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.sin %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def tan(operand, *args, **kwargs): + sin_res = ops.sin(operand) + cos_res = ops.cos(operand) + operand = ops.truediv(sin_res, cos_res) + return operand, V.kernel.var_info[operand] + + @staticmethod + def lgamma(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def erf(operand, *args, **kwargs): + # Check scalar + op_type = V.kernel.var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.erf(operand) + result = ops.extractelement(val, 0) + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.erf %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def cosh(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def sinh(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def tanh(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + + # Check scalar + op_type = V.kernel.var_info[operand] + if op_type[0] == 1: + operand = ops.broadcast(operand, 4) + val = ops.tanh(operand) + result = ops.extractelement(val, 0) + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.tanh %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def acos(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def acosh(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def asin(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def asinh(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def atan2(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def atan(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def atanh(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def copysign(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def erfc(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def erfinv(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def frexp(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def hypot(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def log10(operand, *args, **kwargs): + val_ln = ops.log(operand) + + tile_size, dtype = V.kernel.var_info[val_ln] + inv_ln10 = 1/math.log(10) + const_op = ops.constant(inv_ln10, dtype) + + # Multiply: ln(x) * (1/ln(10)) + result = ops.mul(val_ln, const_op) + return result, V.kernel.var_info[result] + + @staticmethod + def log2(operand, *args, **kwargs): + val_ln = ops.log(operand) + tile_size, dtype = V.kernel.var_info[val_ln] + inv_ln10 = 1/math.log(2) + const_op = ops.constant(inv_ln10, dtype) + + # Multiply: ln(x) * (1/ln(10)) + result = ops.mul(val_ln, const_op) + return result, V.kernel.var_info[result] + + @staticmethod + def log(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.log %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def log1p(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] + const_one = ops.constant(1, dtype) + + val_add = ops.add(operand, const_one) + result = ops.log(val_add) + return result, V.kernel.var_info[result] + + @staticmethod + def nextafter(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def logical_and(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": + operand1 = ops.to_bool(operand1) + if V.kernel.var_info[operand2][1] != "i1": + operand2 = ops.to_bool(operand2) + result = ops.and_(operand1, operand2) + return result, V.kernel.var_info[result] + + @staticmethod + def logical_or(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": + operand1 = ops.to_bool(operand1) + if V.kernel.var_info[operand2][1] != "i1": + operand2 = ops.to_bool(operand2) + result = ops.or_(operand1, operand2) + return result, V.kernel.var_info[result] + + @staticmethod + def logical_xor(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": + operand1 = ops.to_bool(operand1) + if V.kernel.var_info[operand2][1] != "i1": + operand2 = ops.to_bool(operand2) + result = ops.xor(operand1, operand2) + return result, V.kernel.var_info[result] + + @staticmethod + def logical_not(operand, *args, **kwargs): + op_info = V.kernel.var_info[operand] + tile_size = op_info[0] + dtype = op_info[1] + zero_const = ops.constant(0, dtype) + result = ops.eq(operand, zero_const) + return result, V.kernel.var_info[result] + + @staticmethod + def bitwise_and(operand1, operand2, *args, **kwargs): + # Float check + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): + raise ValueError("Bitwise AND not supported for floats") + result = ops.and_(operand1, operand2) + return result, V.kernel.var_info[result] + + @staticmethod + def bitwise_not(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] + # Float check + if V.kernel.var_info[operand][1].startswith("f"): + raise ValueError("Bitwise NOT not supported for floats") + neg_one = ops.constant(-1, dtype) + result = ops.xor(operand, neg_one) + return result, V.kernel.var_info[result] + + @staticmethod + def bitwise_or(operand1, operand2, *args, **kwargs): + # Float check + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): + raise ValueError("Bitwise AND not supported for floats") + + result = ops.or_(operand1, operand2) + return result, V.kernel.var_info[result] + + @staticmethod + def bitwise_xor(operand1, operand2, *args, **kwargs): + # Float check + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): + raise ValueError("Bitwise AND not supported for floats") + result = ops.xor(operand1, operand2) + return result, V.kernel.var_info[result] + + @staticmethod + def bitwise_left_shift(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def bitwise_right_shift(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def rsqrt(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(f'math.rsqrt %{operand}', shape, **kwargs), [tile_size, dtype] + + @staticmethod + def sigmoid(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + one = ops.constant(1, dtype) + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, dtype] + + @staticmethod + def fmod(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def isinf(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def isnan(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def round(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + op_str = f"math.roundeven %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + @staticmethod + def floor(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + op_str = f"math.floor %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + @staticmethod + def sign(operand, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def trunc(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + op_str = f"math.trunc %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + @staticmethod + def ceil(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + + if dtype.startswith("f"): + op_str = f"math.ceil %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + else: + return operand, [tile_size, dtype] + + # Logical operations + @staticmethod + def neg(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype.startswith("f"): + operand = ops.to_dtype(operand, "f32") + op_str = f"arith.negf %{operand}" + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + + @staticmethod + def reciprocal(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + tile_size, dtype = op_type[0], op_type[1] + if dtype.startswith("i"): + openand = ops.to_dtype(operand, "f32") + op_type = V.kernel.var_info[operand] + tile_size, dtype = op_type[0], op_type[1] + + return ops.truediv(ops.constant(1.0, dtype), operand), [tile_size, dtype] + + @staticmethod + def eq(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oeq" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "eq" + else: + raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") + + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] + + @staticmethod + def ne(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "one" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "ne" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] + + @staticmethod + def lt(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "olt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "slt" + else: + raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") + + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] + + @staticmethod + def gt(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ogt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sgt" + else: + raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") + + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] + + @staticmethod + def le(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ole" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sle" + else: + raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") + + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] + + @staticmethod + def ge(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oge" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sge" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] + + @staticmethod + def add(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.add{ret_type[0]}' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def sub(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.sub{ret_type[0]}' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def mul(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.mul{ret_type[0]}' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def pow(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + # Type check & auto cast + if ret_type.startswith("f"): + operand1 = ops.to_dtype(operand1, "f32") + + # Type check & auto cast + if ret_type.startswith("f"): + operand2 = ops.to_dtype(operand2, "f32") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + op_str = f"math.pow{ret_type[0]} %{operand1}, %{operand2}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def and_(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + op_str = f'arith.andi %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def or_(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + op_str = f'arith.ori %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def xor(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + op_str = f'arith.xori %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def lshift(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def rshift(operand1, operand2, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def truncdiv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if ret_type.startswith("f"): + raise ValueError("truncdiv is strictly for integers. Use truediv for floats.") + + # arith.divsi: Signed Integer Division (Result is truncated) + op_str = f'arith.divsi %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def floordiv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if ret_type.startswith("f"): + # Float의 floor division은 보통 divf 후 floor를 하므로 여기선 정수만 처리 + raise ValueError("floordiv implementation expects integers based on definition.") + + # arith.floordivsi: Floor Division for Signed Integers + op_str = f'arith.floordivsi %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def truediv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if not ret_type.startswith("f"): + raise ValueError(f"truediv expects float inputs, but got {ret_type}. Use int_truediv for integers.") + + op_str = f'arith.divf %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def int_truediv(operand1, operand2, *args, **kwargs): + """ + True division for Integers (Int -> Float). + Promotes integers to floats, then performs floating-point division. + """ + tile_size, src_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + if not src_type.startswith("f"): + target_float_type = "f32" + operand1 = ops.to_dtype(operand1, target_float_type) + operand2 = ops.to_dtype(operand2, target_float_type) + src_type = target_float_type + + result = ops.truediv(operand1, operand2) + return result, V.kernel.var_info[result] + + @staticmethod + def mod(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + raise NotImplementedError("Not support remainder operation for floating point") + else: + opcode = f'arith.remsi' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def remainder(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + + if ret_type.startswith("f"): + opcode = 'arith.remf' + else: + opcode = 'arith.remsi' # Signed Integer Remainder (LHS sign) + + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def square(operand, *args, **kwargs): + result = ops.mul(operand, operand) + return result, V.kernel.var_info[result] + + @staticmethod + def fma(operand1, operand2, operand3, *args, **kwargs): + result = ops.mul(operand1, operand2) + result = ops.add(result, operand3) + return result, V.kernel.var_info[result] + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # PyTorchSim specific operations + + @staticmethod + def alloc(size, src_type, *args, **kwargs): + return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] + + @staticmethod + def extractelement(operand, idx, *args, **kwargs): + op_type = V.kernel.var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + op_str = f"vector.extract %{operand}[{idx}]" + shape = f"{dtype} from {shape}" + return format_mlir_op(op_str, shape, **kwargs), [1, dtype] + + @staticmethod + def ext(operand, dtype, *args, **kwargs): + op_type = V.kernel.var_info[operand] + shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" + target_type = f"vector<{op_type[0]}x{dtype}>" if op_type[0] > 1 else f"{dtype}" + if dtype[0] == "f": + opcode = f'arith.extf' + else: + opcode = f'arith.extui' + op_str = f'{opcode} %{operand}' + shape = f"{shape} to {target_type}" + return format_mlir_op(op_str, shape, **kwargs), [op_type[0], dtype] + + @staticmethod + def to_bool(operand, *args, **kwargs): + tile_size, ret_type = V.kernel.var_info[operand] + if ret_type == "i1": + return operand, [tile_size, ret_type] + + const_zero = ops.constant(0, ret_type) + if tile_size > 1: + const_zero = ops.broadcast(const_zero, tile_size) + ret = ops.ne(operand, const_zero) + return ret, [tile_size, "i1"] + @staticmethod + def step(size, dtype, *args, **kwargs): + index_shape = f"vector<{size}x{dtype}>" + op_str = f"vector.step" + return format_mlir_op(op_str, index_shape, **kwargs), [size, dtype] + + @staticmethod + def index_cast(operand, target_type, *args, **kwargs): + op_type = V.kernel.var_info[operand] + src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] + des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type + op_str = f"arith.index_cast %{operand}" + shape = f"{src_shape} to {des_shape}" + return format_mlir_op(op_str, shape, **kwargs), [op_type[0], target_type] + + @staticmethod + def shape_cast(operand, src_shape, dst_shape, *args, **kwargs): + operand_type = V.kernel.var_info[operand] + op_str = f"vector.shape_cast %{operand}" + shape = f"{src_shape} to {dst_shape}" + return format_mlir_op(op_str, shape, **kwargs), operand_type + + @staticmethod + def extract_strided_slice(operand, target_size, offsets=None, sizes=None, strides=None, *args, **kwargs): + op_type = V.kernel.var_info[operand] + src_size = op_type[0] + dtype = op_type[1] + + if offsets is None: + offsets = [0] + if sizes is None: + sizes = [target_size] + if strides is None: + strides = [1] + + src_shape = f"vector<{src_size}x{dtype}>" + dst_shape = f"vector<{target_size}x{dtype}>" + + offsets_str = ", ".join(str(o) for o in offsets) + sizes_str = ", ".join(str(s) for s in sizes) + strides_str = ", ".join(str(s) for s in strides) + + # Build attributes dict for offsets, sizes, strides + built_attributes = { + "offsets": f"[{offsets_str}]", + "sizes": f"[{sizes_str}]", + "strides": f"[{strides_str}]" + } + + # Merge with any existing attributes from kwargs + existing_attributes = kwargs.get('attributes', {}) + if isinstance(existing_attributes, dict): + merged_attributes = {**built_attributes, **existing_attributes} + elif isinstance(existing_attributes, str): + built_attrs_str = ", ".join(f"{k}={v}" for k, v in built_attributes.items()) + merged_attributes = f"{built_attrs_str}, {existing_attributes}" + else: + merged_attributes = built_attributes + + op_str = f"vector.extract_strided_slice %{operand}" + shape = f"{src_shape} to {dst_shape}" + + # Pass merged attributes to format_mlir_op + updated_kwargs = {**kwargs, 'attributes': merged_attributes} + return format_mlir_op(op_str, shape, **updated_kwargs), [target_size, dtype] + + @staticmethod + def vlane_offset(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.add{ret_type[0]}' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + + @staticmethod + def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_name, *args, **kwargs): + if red_size == 1: + final_reduced_shape = f"{type_name}" + line = reduction_combine_vec(red_type, acc, init, axis=0, shape=red_shape, reduced_shape=final_reduced_shape) + else: + final_reduced_shape = f"vector<{red_size}x{type_name}>" + new_vshape= f"vector<{vec_size//red_size}x{red_size}x{type_name}>" + value = ops.shape_cast(acc, red_shape, new_vshape) + line = reduction_combine_vec(red_type, value, init, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape) + return line, [red_size, type_name] + + @staticmethod + def vector_shuffle(operand, indices, operand2=None, *args, **kwargs): + tile_size1, dtype1 = V.kernel.var_info[operand] + if operand2 is None: + operand2 = operand + tile_size2, dtype2 = V.kernel.var_info[operand2] + if dtype1 != dtype2: + raise ValueError( + f"vector_shuffle expects same element type, got {dtype1} and {dtype2}" + ) + total_size = tile_size1 + tile_size2 + for idx in indices: + if idx < -1 or idx >= total_size: + raise ValueError( + f"vector_shuffle index out of range: {idx}, expected in [-1, {total_size - 1}]" + ) + vt1 = f"vector<{tile_size1}x{dtype1}>" + vt2 = f"vector<{tile_size2}x{dtype1}>" + idx_str = ", ".join(str(i) for i in indices) + op_str = f"vector.shuffle %{operand}, %{operand2} [{idx_str}]" + return format_mlir_op(op_str, f"{vt1}, {vt2}", **kwargs), [len(indices), dtype1] + + @staticmethod + def constant_mask(select_min, N, *args, **kwargs): + vals = ", ".join("true" if x else "false" for x in select_min) + op_str = f"arith.constant dense<[{vals}]>" + return format_mlir_op(op_str, f"vector<{N}xi1>", **kwargs), [N, "i1"] + + @staticmethod + def bitonic_sort(operand, descending=False, *args, **kwargs): + def _compute_bitonic_stages(N: int, descending: bool): + assert N >= 2 and (N & (N - 1)) == 0, "N must be power-of-2 >= 2" + stages = [] + size = 2 + while size <= N: + stride = size // 2 + while stride >= 1: + merged_shuffle = list(range(N)) + merged_mask = [None] * N + + for start in range(0, N, size): + blk_dir = "ASCENDING" if (start // size) % 2 == 0 else "DESCENDING" + for i in range(start, start + size - stride, stride * 2): + for j in range(stride): + a, b = i + j, i + j + stride + merged_shuffle[a] = b + merged_shuffle[b] = a + if blk_dir == "ASCENDING": + merged_mask[a] = True # a = min + merged_mask[b] = False # b = max + else: + merged_mask[a] = False # a = max + merged_mask[b] = True # b = min + select_min = [bool(x) if x is not None else False for x in merged_mask] + if descending: + select_min = [not x for x in select_min] + stages.append({ + "shuffle": merged_shuffle, + "select_min": select_min, + }) + stride //= 2 + size *= 2 + return stages + + tile_size, _ = V.kernel.var_info[operand] + cur = operand + for stage in _compute_bitonic_stages(tile_size, descending): + mask = ops.constant_mask(stage["select_min"], tile_size) + shuffled = ops.vector_shuffle(cur, stage["shuffle"]) + vmin = ops.minimum(cur, shuffled) + vmax = ops.maximum(cur, shuffled) + cur = ops.where(mask, vmin, vmax) + return cur, V.kernel.var_info[cur] + + @staticmethod + def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, **kwargs): + if compute_vec_size == 1: + vshape = f"{mlir_dtype}" + operation = "affine.load" + line = f"{operation} %{buffer}[{indices}]" + shape = buffer_shape + else: + vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" + operation = "affine.vector_load" + line = f"{operation} %{buffer}[{indices}]" + shape = f"{buffer_shape}, {vshape}" + return format_mlir_op(line, shape, **kwargs), [compute_vec_size, mlir_dtype] + + @staticmethod + def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, **kwargs): + compute_vec_size, mlir_dtype = V.kernel.var_info[operand][0], V.kernel.var_info[operand][1] + + if compute_vec_size == 1: + vshape = f"{mlir_dtype}" + operation = "affine.store" + line = f"{operation} %{operand}, %{buffer}[{indices}]" + shape = buffer_shape + else: + vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" + operation = "affine.vector_store" + line = f"{operation} %{operand}, %{buffer}[{indices}]" + shape = f"{buffer_shape}, {vshape}" + line = format_mlir_op(line, shape, **kwargs) + + if buffer_name is not None: + return common.DeferredLine(buffer_name, line), [None, None] + else: + return line, [None, None] + + @staticmethod + def affine_apply(map_var, indices, indirect_dims=None, comment=None, *args, **kwargs): + # Format indices arguments + indices_str = ", ".join([f"%{i}" for i in indices]) + op_str = f"affine.apply #{map_var}({indices_str})" + + # Add indirect dimensions if provided + if indirect_dims: + indirect_str = ", ".join(indirect_dims) + op_str += f"[{indirect_str}] {{indirect_access}}" + if comment: + op_str += f" // {comment}" + return op_str, [1, "index"] + + @staticmethod + def affine_map(dim_names, expr_str, symbol_names=None, comment=None, *args, **kwargs): + # Handle dim_names as list or string + if isinstance(dim_names, list): + dims_str = ", ".join([str(dim) for dim in dim_names]) + else: + dims_str = dim_names + + # Build the map string + if symbol_names: + if isinstance(symbol_names, list): + symbols_str = ", ".join(symbol_names) + else: + symbols_str = symbol_names + map_str = f"affine_map<({dims_str})[{symbols_str}] -> ({expr_str})>" + else: + map_str = f"affine_map<({dims_str}) -> ({expr_str})>" + + if comment: + map_str += f" // {comment}" + + return map_str, [1, "map"] diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 23be941c..22d1011b 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -7,24 +7,27 @@ from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel +from torch.utils._ordered_set import OrderedSet from torch._inductor import config from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode from torch._inductor.utils import IndentedBuffer from torch._inductor.virtualized import V from torch._inductor.ir import LoopBody from torch._inductor import dependencies +from torch._inductor.codegen.common import BackendFeature from . import mlir_common from . import mlir_lowering # DO NOT REMOVE THIS LINE, it is used for lowering +from . import mlir_decomposition # DO NOT REMOVE THIS LINE, it is used for decomposition class MLIRScheduling(BaseScheduling): count = 0 target_kernel = MLIRKernel def __init__(self, scheduler): self.scheduler = scheduler - self.scheduler.can_fuse_origin = self.scheduler.can_fuse - self.scheduler.can_fuse = self.can_fuse_with_exceptions - #self.scheduler.enter_context = self.enter_context_fixed # FIXME. Monkey patch: For fixing the inductor bug + if scheduler is not None: + self.scheduler.can_fuse_origin = self.scheduler.can_fuse + self.scheduler.can_fuse = self.can_fuse_with_exceptions # FIXME. Monkey patch: For prolouge fusion self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._ready_to_flush = False self.outer_function = set() @@ -32,51 +35,26 @@ def __init__(self, scheduler): self.max_fusion_size = 5 def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + if not extension_config.CONFIG_FUSION_PROLOGUE: + return self.scheduler.can_fuse_origin(node1, node2) + # Extract base template node base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] - if node1.get_device() != node2.get_device(): - return False - if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): - return False - if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): - # For matmul/bmm+reduction case - size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) - target_symbol = symbols("r0") - try: - stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] - stride = int(sympify(stride).coeff(target_symbol)) - except: - return False - - # We can't fuse dim=-1 - layout_possible = stride != 1 - # Directed linked? - dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 - dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) - return size_match and layout_possible and dependency_check and dependency_size - - # For prologue fusion case - if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + # Case 3: Prologue(Pointwise) + Tempalte + if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: target_node = base_template_node2[0].node - if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': - return False - if node1.is_reduction(): + + # Check if template supports prologue fusion + if not getattr(target_node.template, 'support_prologue_fusion', False): return False + if len(node1.read_writes.writes) != 1: return False if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME return False - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): - return False # We don't fuse this edge case... if base_template_node2[0].group[1][0][0] == 1: return False @@ -84,28 +62,39 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: node1 = self.revert_group(node1) return True - return self.scheduler.can_fuse_origin(node1, node2) + def _set_flush_status(self, status: bool): self._ready_to_flush = status + def reset_kernel_group(self): + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + + def get_backend_features(self, device): + """Return a set of .codegen.common.BackendFeature()""" + return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT]) + def can_fuse_vertical(self, node1, node2): return self.can_fuse_horizontal(node1, node2) + def can_fuse_multi_outputs_template(self, node1, node2): + return self.can_fuse_horizontal(node1, node2) + def can_fuse_horizontal(self, node1, node2): if not extension_config.CONFIG_FUSION: return False + if (len(node1.get_nodes())+ len(node2.get_nodes())) > self.max_fusion_size: return False + _, (vars1, reduce1) = node1.group _, (vars2, reduce2) = node2.group - - # Reduction is currently not supported - if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template() and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION: - return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users - if node1.is_reduction() or node2.is_reduction(): - return False + # For input/dependency checks + reads1 = {dep.name for dep in node1.read_writes.reads} + reads2 = {dep.name for dep in node2.read_writes.reads} + writes1 = {dep.name for dep in node1.read_writes.writes} + writes2 = {dep.name for dep in node2.read_writes.writes} # Can't fuse two template node if node1.is_template() and node2.is_template(): @@ -114,17 +103,39 @@ def can_fuse_horizontal(self, node1, node2): if '_unsafe_index' in node1.get_nodes()[0].node.origins or "_unsafe_index" in node2.get_nodes()[0].node.origins: return False - # Check template node fusion - if node1.is_template() or node2.is_template(): + # Extract base template node + base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] + base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] + + # Case 0: Reduction fusion + if ( + node1.is_reduction() + and node2.is_reduction() + and not node1.is_template() + and not node2.is_template() + and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION + ): + # 1) Same loop/iteration domain + same_iter = vars1 == vars2 and reduce1 == reduce2 + # 2) No data dependency between the two reductions + no_dependency = not ( + writes1 & (reads2 | writes2) or writes2 & (reads1 | writes1) + ) + return same_iter and no_dependency + + # Case 1: Template + Pointwise fusion + if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) - template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) - if template_node1 and len(node1.get_nodes()) == 1 and isinstance(template_node1.node.template, MLIRMaxPoolTemplate) or \ - template_node2 and len(node2.get_nodes()) == 1 and isinstance(template_node2.node.template, MLIRMaxPoolTemplate): + template_node = base_template_node1[0] + epilogue_node = node2 + + # Check if template supports epilogue fusion + if not getattr(template_node.node.template, 'support_epilogue_fusion', False): + return False + + if isinstance(template_node.node.template, MLIRMaxPoolTemplate): return False # Pointwise check @@ -133,26 +144,76 @@ def can_fuse_horizontal(self, node1, node2): if v1_total != v2_total: return False - # Pattern check - template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) - has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) - if not has_depedency: + # Pattern check: check data dependency between act_node and template_node + template_sched_nodes = list(template_node.get_nodes()) + # Buffers produced by the template (its outputs) + template_writes = { + dep + for n in template_sched_nodes + for dep in n.read_writes.writes + } + # Buffers still required by the activation node (unmet) or read by it + epilogue_unmet = { dep for dep in epilogue_node.unmet_dependencies } + has_dependency = bool(template_writes) and epilogue_unmet.issubset(template_writes) and not bool(reads1 & writes2) + if not has_dependency: return False # Revert act_node.group : simplify_and_reorder() modified _body, _size, group - if template_node.group != act_node.group: + if template_node.group != epilogue_node.group: # We don't fuse this case... - if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + if getattr(template_node.node.template, 'support_prologue_fusion', False) and template_node.group[1][0][0] == 1: return False - if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): + if list(template_node.group[1][0]) != list(epilogue_node.get_nodes()[0].node.data.get_size()): return False - self.revert_group(act_node) + self.revert_group(epilogue_node) return True - # Check elementwise fusion - if vars1 == vars2 and reduce1 == reduce2: - return True + # Case 2: Tempalte + Reduction fusion + if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: + target_node = base_template_node1[0].node + + # Check if template supports reduction fusion + if not getattr(target_node.template, 'support_reduction_fusion', False): + return False + + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) + target_symbol = symbols("r0_0") + try: + stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] + stride = int(sympify(stride).coeff(target_symbol)) + except: + return False + + # We can't fuse dim=-1 & N == 1 + layout_possible = stride != 1 and (1 not in node1.node.get_size()) + # Directed linked? + dependency_check = writes1 & reads2 + dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) + return size_match and layout_possible and dependency_check and dependency_size + + # Case 3: Prologue(Pointwise) + Tempalte + # if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: + # from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + # from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + + # target_node = base_template_node2[0].node + # # Currently only BMM, MM support prologue fusion + # if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + # return False + + # if len(node1.read_writes.writes) != 1: + # return False + # if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME + # return False + + # # We don't fuse this edge case... + # if base_template_node2[0].group[1][0][0] == 1: + # return False + + # if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + # node1 = self.revert_group(node1) + # return True return False def revert_group(self, act_nodes, args=None, var_ranges=None): @@ -165,6 +226,8 @@ def revert_group(self, act_nodes, args=None, var_ranges=None): act_node.node.get_store_function(), (args if act_node.node.get_reduction_type() else args[:1]), var_ranges, + args[0], + args[1] ) index_size = [] reduce_size = [] @@ -180,12 +243,13 @@ def revert_group(self, act_nodes, args=None, var_ranges=None): def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) - def codegen_nodes(self, nodes): + def codegen_node(self, _node): + nodes = _node.get_nodes() _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group - # Note: We assume that ther is at least one loop in the nodes + # Note: We assume that there is at least one loop in the nodes # But, inductor simplifies the group, there could be no loop # In that case, we add dummy loop(size=1) to the group if len(group) == 0: @@ -210,17 +274,12 @@ def codegen_nodes(self, nodes): kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) - kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, - ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) + src_code, meta_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name = self.define_kernel(src_code, meta_code, kernel_name_candidate, ex_kernel.vector_lane, + ex_kernel.spad_info, origins={str(i) for node in nodes for i in node.node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() args = ", ".join(args) - eager_mode = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) - if (eager_mode): - V.graph.wrapper_code.writeline( - f"yield ({kernel_name}, ({args}))" - ) self._set_flush_status(True) def ready_to_flush(self): @@ -230,70 +289,58 @@ def codegen_sync(self): pass def flush(self): - self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + self.reset_kernel_group() self._set_flush_status(False) def define_function(self, kernel): partial_code, function_name = kernel.def_function() if partial_code is not None and function_name not in self.outer_function: with V.set_kernel_handler(kernel): - code = partial_code.finalize() + code = partial_code.finalize_all() wrapper = V.graph.wrapper_code wrapper.header.writeline(code) self.outer_function.add(function_name) - def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): + def define_kernel(self, src_code, meta_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: kernel_name = wrapper.src_to_kernel[src_code] else: wrapper.src_to_kernel[src_code] = kernel_name - codecache_def = IndentedBuffer() codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") codecache_def.writeline(f"vectorlane_size={vector_lane},") codecache_def.writeline(f"loop_size={loop_size},") codecache_def.writeline(f"spad_info={spad_info},") codecache_def.writeline(f"origins={origins},") - codecache_def.writeline("arg_attributes=arg_attributes,") + codecache_def.writeline(f"arg_attributes={meta_code},") codecache_def.writeline(f"vlen={extension_config.vpu_vector_length_bits})") - wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) + wrapper.define_kernel(kernel_name, codecache_def.getvalue(), gpu=False) return kernel_name - def codegen_template(self, template_node, epilogue_nodes): - # Handle prologue pattern - prologue_nodes = [] - if not template_node.is_template(): - epilogue_nodes = [template_node] + epilogue_nodes - for i, node in enumerate(epilogue_nodes): - if node.is_template(): - template_node = node - prologue_nodes = epilogue_nodes[:i] - epilogue_nodes = epilogue_nodes[i+1:] - break - + def codegen_template(self, template_node, epilogue_nodes, prologue_nodes): # Generate template code template_buffer = template_node.node kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) + src_code, meta_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) - with V.set_kernel_handler(kernel): - kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, - kernel.loop_size, origins={str(i) for i in template_node.node.origins}) + with kernel: + all_nodes = [template_node] + (epilogue_nodes or []) + (prologue_nodes or []) + origins = {str(i) for n in all_nodes for i in n.node.origins} + kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, + kernel.loop_size, origins=origins) self.define_function(kernel) kernel.call_kernel(kernel_name) V.graph.removed_buffers |= kernel.removed_buffers _, args, _, _ = self.kernel_group.args.mlir_argdefs() - eager_mode = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) - if (eager_mode): - target_kernel_name = kernel_name if kernel.outer_func_name is None else kernel.outer_func_name + f"_{len(args)}" - args = ", ".join(args) - V.graph.wrapper_code.writeline( - f"yield ({target_kernel_name}, ({args}))" - ) self._set_flush_status(True) def enter_context_fixed(self, node): diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py new file mode 100644 index 00000000..a3ae6192 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -0,0 +1,550 @@ +import math # sqrt +import sympy + +from typing import List, Optional + +import torch +from torch import empty_strided +from torch._inductor.ir import IRNode, TensorBox, FixedLayout +from torch._inductor.virtualized import V +from torch._inductor.select_algorithm import realize_inputs +from torch.backends.cuda import flash_sdp_enabled, mem_efficient_sdp_enabled + +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel + + +def _make_offset_map_with_sym(strides, sym_dim, sym_stride, offset=0): + """Like _make_offset_map but injects a block symbol ``s`` into dimension ``sym_dim``. + + The effective index for that dimension becomes ``d{sym_dim} + sym_stride * s``. + Use this to keep ``affine.for`` bounds static and encode the block contribution + directly inside the ``affine.apply`` call that computes the DRAM offset. + + Args: + strides: per-dimension DRAM strides. + sym_dim: which dimension carries the block symbol. + sym_stride: multiplier for the symbol (1 for abs-position loops like FLASH + ``%blk``; ``BlkS`` for block-index loops like PARTIAL ``%blk``). + offset: constant layout offset. + + Returns: + MLIR affine_map string with one symbol, e.g. + ``affine_map<(d0, d1, d2)[s] -> (d0 * 8192 + (d1 + 128 * s) * 64 + d2)>`` + """ + n = len(strides) + terms = [] + for j, sv in enumerate(strides): + sv = int(sv) + if sv == 0: + continue + if j == sym_dim: + inner = f"d{j} + s" if sym_stride == 1 else f"d{j} + {sym_stride} * s" + terms.append(f"({inner})" if sv == 1 else f"({inner}) * {sv}") + else: + terms.append(f"d{j}" if sv == 1 else f"d{j} * {sv}") + try: + off = int(offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(n)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str})[s] -> ({expr})>" + + +def _make_offset_map(strides, offset=0): + """Generate an MLIR affine_map string for a flat DRAM base-address. + + Args: + strides: list of integer per-dimension strides. + A stride of 0 means the dimension does not contribute. + offset: constant layout offset (e.g. from IRNode.get_layout().offset). + + Returns: + MLIR affine_map string, e.g. ``affine_map<(d0, d1) -> (d0 * 128 + d1)>`` + """ + n = len(strides) + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(n)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str}) -> ({expr})>" + + +def flash_sdpa_args( + query : TensorBox, + key : TensorBox, + value : TensorBox) -> list: + """ + Arg processing for flash SDPA. + Its logic is based on: + mm_args() which is in torch._inductor.kernel.mm_common.py (142 line). + """ + + # Materialize input buffers for the codegen backend. + query, key, value = realize_inputs(query, key, value) + + # query : (n, hq, l, e) + # key : (n, h, s, e) + # value : (n, h, s, ev) + # out : (n, hq, l, ev) + # n: Batch size + # hq: query's head counts, h: key and value's head counts. + # l: target sequence lenght and s: source sequence length. + # e: embeding dimension of the query and key and ev: embeding dimension of the value. + nq, hq, l, eq = query.get_size() + nk, hk, sk, ek = key.get_size() + nk, hv, sv, ev = value.get_size() + + n = V.graph.sizevars.guard_equals(nq, nk) + n = V.graph.sizevars.guard_equals(nq, nk) + + h = V.graph.sizevars.guard_equals(hk, hv) + s = V.graph.sizevars.guard_equals(sk, sv) + e = V.graph.sizevars.guard_equals(eq, ek) + + # While there are no theoretical requirements for e == ev, + # this implementation currently enforces e == ev for simplicity. + if e != ev: + raise NotImplementedError( + "Flash SDPA currently requires matching head dimensions between query and value (e == ev)." + ) + + # Minimal GQA support (single-batch only for now). + # We map each query head to a KV head by grouping: hq = g * h. + if hq != h: + if n != 1: + raise NotImplementedError("Flash SDPA GQA is currently supported only for n == 1.") + if (hq % h) != 0: + raise NotImplementedError(f"Flash SDPA GQA requires hq % h == 0 (hq: {hq}, h: {h}).") + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [n, hq, l, ev] + ) + + return [n, hq, h, l, s, e, ev, layout, query, key, value] + +def calculate_scale(query: torch.Tensor, scale: float) -> float: + """ + Calculate the scaling factor based on the head dimension if scale is None + Otherwise, use the provided scale. + """ + if scale is None: + return 1.0 / math.sqrt(query.layout.size[-1]) + else: + return scale + + +FLASH_SDPA_TEMPLATE = r""" +// SDPA kernel +// b = {{ b }} +// l = {{ l }} +// s = {{ s }} +// e = {{ e }} +// tile_l = {{ tile_l }} +// tile_s = {{ tile_s }} +// tile_e = {{ tile_e }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[out], names_str="query, key, value, out", input_reorder=input_reorder)}} { + // Inputs + {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} + + // Output + {{ kernel.def_sram_buffer("out", out_tile_desc, indent_size=2) }} + + // Intermediate buffers + {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + + // Constants + %c0 = arith.constant 0.0 : {{ data_stype }} + %c1 = arith.constant 1.0 : {{ data_stype }} + %c_scale = arith.constant {{ scale }} : {{ data_stype }} + %c_neg_inf = arith.constant -1.0e+30 : {{ data_stype }} + + %v0_c = arith.constant dense<0.0> : vector<{{ chunk_size }}x{{ data_stype }}> + %v0_l = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + %v0_s = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + %v0_2x = arith.constant dense<0.0> : vector<2x{{ data_stype }}> + + %v_neg_inf_c = arith.constant dense<-1.0e+30> : vector<{{ chunk_size }}x{{ data_stype }}> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2x{{ data_stype }}> + + %v_scale = vector.broadcast %c_scale : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %index0 = 0 to {{ b }} { + affine.for %index3 = 0 to 1 step 1 { + affine.for %index1 = 0 to {{ l }} step {{ tile_l }} { + %q_dram_offset = affine.apply {{ q_offset_map }}(%index0, %index1, %index3) + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, indent_size=8, dram_stride=q_dram_stride, dram_offset="q_dram_offset") }} + + affine.vector_store %v0_l, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %qt_buffer2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ q_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + %ot_buffer2D = memref.reinterpret_cast %out_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ out_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + + affine.for %index2 = 0 to {{ s }} step {{ tile_s }} { + %k_dram_offset = affine.apply {{ k_offset_map }}(%index0, %index2, %index3) + {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, indent_size=10, dram_stride=k_dram_stride, dram_offset="k_dram_offset") }} + %v_dram_offset = affine.apply {{ v_offset_map }}(%index0, %index2, %index3) + {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, indent_size=10, dram_stride=v_dram_stride, dram_offset="v_dram_offset") }} + + affine.vector_store %v0_s, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + + %k_buffer2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1> + %vt_buffer2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1> + + + // key @ query.t and scaling. + linalg.matmul + { idx_map = array } + ins(%k_buffer2D, %qt_buffer2D : memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1>, memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + + %raw_mul_vec = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %scaled_mul_vec = arith.mulf %raw_mul_vec, %v_scale : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %scaled_mul_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // Find new max. + %old_max = affine.vector_load %max_buffer[0,0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %chunk_max_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_max=%v_neg_inf_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_val = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_max = arith.maximumf %chunk_val, %iter_max : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_max : vector<{{ chunk_size }}x{{ data_stype }}> + } { accumulation_loop=true } + + %max_cast = vector.shape_cast %chunk_max_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %max_reduced_1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %max_shuffled = vector.shuffle %max_reduced_1, %max_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %max_reduced_2 = arith.maximumf %max_reduced_1, %max_shuffled : vector<2x{{ data_stype }}> + + %new_max = arith.maximumf %max_reduced_2, %old_max : vector<2x{{ data_stype }}> + affine.vector_store %new_max, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // Compute rescale factors: exp(old_max - new_max) + %max_diff = arith.subf %old_max, %new_max : vector<2x{{ data_stype }}> + %max_diff_scalar = vector.extract %max_diff[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + + %rescale_bcast_e = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + %exp_rescale_e = math.exp %rescale_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + + %rescale_bcast_2 = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<2x{{ data_stype }}> + %exp_rescale_2 = math.exp %rescale_bcast_2 : vector<2x{{ data_stype }}> + + + // Rescale previous out and sum accumulators + %old_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %rescaled_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2x{{ data_stype }}> + + + // Shift scores and apply exp: exp(x - new_max) + %scaled_scores_reload = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %new_max_scalar = vector.extract %new_max[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %new_max_bcast = vector.broadcast %new_max_scalar : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + %shifted_scores = arith.subf %scaled_scores_reload, %new_max_bcast : vector<{{ tile_s }}x{{ data_stype }}> + %exp_scores = math.exp %shifted_scores : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %exp_scores, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // accumulate current sum + %chunk_sum_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_sum=%v0_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_exp = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_sum = arith.addf %chunk_exp, %iter_sum : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_sum : vector<{{ chunk_size }}x{{ data_stype }}> + } { accumulation_loop=true } + + %zero_2x = vector.broadcast %c0 : {{ data_stype }} to vector<2x{{ data_stype }}> + %sum_cast = vector.shape_cast %chunk_sum_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %sum_reduced_1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %sum_shuffled = vector.shuffle %sum_reduced_1, %sum_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %sum_reduced_2 = arith.addf %sum_reduced_1, %sum_shuffled : vector<2x{{ data_stype }}> + + %new_sum = arith.addf %sum_reduced_2, %rescaled_sum : vector<2x{{ data_stype }}> + affine.vector_store %new_sum, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // value.t @ mul + linalg.matmul + { idx_map = array } + ins(%vt_buffer2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + outs(%ot_buffer2D : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + } { accumulation_loop=true } + + // out @ row_sum^(-1) + %final_row_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %one_2x = vector.broadcast %c1 : {{ data_stype }} to vector<2x{{ data_stype }}> + + %reciprocal_row_sum_2x = arith.divf %one_2x, %final_row_sum : vector<2x{{ data_stype }}> + %reciprocal_scalar = vector.extract %reciprocal_row_sum_2x[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %reciprocal_bcast_e = vector.broadcast %reciprocal_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + + %accumulated_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %stable_final_out = arith.mulf %accumulated_out, %reciprocal_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %stable_final_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + %out_dram_offset = affine.apply {{ out_offset_map }}(%index0, %index1, %index3) + {{ kernel.def_dma_op("MVOUT", "out", [], out_tile_desc, indent_size=8, dram_stride=out_dram_stride, dram_offset="out_dram_offset") }} + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true } + return +} +""" + +class MLIRFlashSDPATemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, scale, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.scale = scale + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): + + # Except for kernel, other arguments are usually None. + query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + + if tile_info is None: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = self.select_tile(kernel, l, s, e, n_extra_node, 0, n_prologue_node)[0] + else: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = tile_info + + TOG_latency = l if tile_l > l else tile_l + kernel.loop_size = [TOG_latency, tile_s, tile_e] + + # Select template code + # Other templates will be added according to situations. + nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] + if nr_reduction_nodes: + raise NotImplementedError("FLASH_SDPA_REDUCTION_TEMPLATE is not implemented yet.") + elif prologue_nodes: + raise NotImplementedError("FLASH_SDPA_PROLOGUE_TEMPLATE is not implemented yet.") + else: + template = FLASH_SDPA_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2", "index3": "index3"} + nr_rdim = 0 + + # Prepare tile descriptors for input and output tensors. + # Intermediate buffers (transient data) do not require DRAM settings(dram stride and dram indices) + # as they are not synchronized with external DRAM. + # DRAM and SRAM tile shapes must match. + vlane_stride = 1 + + # (n, l, s, e, ev) + loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] + + + # Hardware constraint: The tile split axis is restricted. + # To accommodate this, we compute (key @ query.t) instead of (query @ key.t). + # SRAM settings + vlane_split_axis = 1 + q_tile_size = [1, tile_l, tile_e] + q_tile_stride = [0, tile_e, 1] + q_tile_desc = mlir_common.MLIRMultiDimTile(q_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + q_tile_desc.set_tile_size_stride(q_tile_size, q_tile_stride) + q_tile_desc.set_name("q_buffer") + q_tile_desc.offset = query.get_layout().offset + # DRAM settings + q_stride = q_tensor.stride() + + # Since we use a weight-stationary approach in the Systolic Array (SA), + # the split axis of the first operand differs from a standard linear algebra matmul. + # The first operand (key) must be split along the column axis. + # This logic aligns with the relationship between the dot product's summation direction and the hardware's accumulation direction in the SA. + # SRAM settings + vlane_split_axis = 2 + k_tile_size = [1, tile_s, tile_e] + k_tile_stride = [0, 1, tile_s] + k_tile_desc = mlir_common.MLIRMultiDimTile(k_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + k_tile_desc.set_tile_size_stride(k_tile_size, k_tile_stride) + k_tile_desc.set_name("k_buffer") + k_tile_desc.offset = key.get_layout().offset + # DRAM settings + k_stride = k_tensor.stride() + + # Since we compute mul = key @ query.t, we perform out.t = (value.t @ Softmax(mul).t).t, + # which simplifies to (value.t @ Softmax(mul)) + # SRAM settings + vlane_split_axis = 1 + v_tile_size = [1, tile_s, tile_e] + v_tile_stride = [0, tile_e, 1] + v_tile_desc = mlir_common.MLIRMultiDimTile(v_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + v_tile_desc.set_tile_size_stride(v_tile_size, v_tile_stride) + v_tile_desc.set_name("v_buffer") + v_tile_desc.offset = value.get_layout().offset + # DRAM settings + v_stride = v_tensor.stride() + + # Output is also stored in transposed format to match the value.t @ Softmax(mul) operation. + # SRAM settings + vlane_split_axis = 1 + out_tile_size = [1, tile_l, tile_e] + out_tile_stride=[0, tile_e, 1] + out_tile_desc = mlir_common.MLIRMultiDimTile(out_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + out_tile_desc.set_tile_size_stride(out_tile_size, out_tile_stride) + out_tile_desc.set_name("out_buffer") + # DRAM settings + out_stride = out.get_layout().stride[1:] + + # Intermediate buffers + + # For mul = key @ query.t + vlane_split_axis = 1 + mul_tile_size = [tile_s, tile_l] + mul_tile_stride = [tile_l, 1] + mul_tile_desc = mlir_common.MLIRMultiDimTile(mul_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + mul_tile_desc.set_tile_size_stride(mul_tile_size, mul_tile_stride) + mul_tile_desc.set_name("mul_buffer") + #FIXME. What is the offset? -> It doesn't matter at this time. + + # For storing maximum values per row + vlane_split_axis = 0 + max_size = [tile_l, 2] + max_stride = [2, 1] + max_desc = mlir_common.MLIRMultiDimTile(max_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + max_desc.set_tile_size_stride(max_size, max_stride) + max_desc.set_name("max_buffer") + + # For storing summation per row + vlane_split_axis = 0 + sum_size = [tile_l, 2] + sum_stride = [2, 1] + sum_desc = mlir_common.MLIRMultiDimTile(sum_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + sum_desc.set_tile_size_stride(sum_size, sum_stride) + sum_desc.set_name("sum_buffer") + + # For reduction + chunk_size = 16 + + # DMA strides and offset affine maps (dram_stride + dram_offset style) + q_dram_stride = [int(q_stride[0]), int(q_stride[1]), int(q_stride[2])] + k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] + v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] + out_dram_stride = [int(out_stride[0]), int(out_stride[1]), int(out_stride[2])] + + q_offset_map = _make_offset_map(q_dram_stride, q_tile_desc.offset) + k_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) + v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) + out_offset_map = _make_offset_map(out_dram_stride, 0) + + # Keep out_idx only for epilogue_info (not in render_options) + out_idx = [loop_dim[0]*out_stride[0], loop_dim[1]*out_stride[1], loop_dim[3]*out_stride[2]] + + kernel.render_options = dict( + KERNEL_NAME = self.name, + kernel = kernel, + b = b, + l = l, + s = s, + e = e, # Input sizes (dram) + tile_l = tile_l, + tile_s = tile_s, + tile_e = tile_e, # Tile sizes (sram) + data_stype="f32", + query = query, + key = key, + value = value, + out = out, # Inputs and output (dram) + q_dram_stride = q_dram_stride, + k_dram_stride = k_dram_stride, + v_dram_stride = v_dram_stride, + out_dram_stride = out_dram_stride, # Per-dim DRAM strides + q_offset_map = q_offset_map, + k_offset_map = k_offset_map, + v_offset_map = v_offset_map, + out_offset_map = out_offset_map, # Affine maps for base address + q_tile_desc = q_tile_desc, + k_tile_desc = k_tile_desc, + v_tile_desc = v_tile_desc, + mul_tile_desc = mul_tile_desc, + out_tile_desc = out_tile_desc, # Tile descriptions (sram) + max_desc = max_desc, + sum_desc = sum_desc, # Intermediate buffer descriptions (sram) + scale = self.scale, + chunk_size = chunk_size, + input_reorder = self.input_reorder # ETC + ) + + code = self._template_from_string(template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["l"], kernel.render_options["s"], kernel.render_options["e"]], [kernel.render_options["tile_l"], kernel.render_options["tile_s"], kernel.render_options["tile_e"]]) + return code + + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + query = self.input_nodes[0] + key = self.input_nodes[1] + value = self.input_nodes[2] + out = self.output_node + + q_tensor = empty_strided(query.layout.size, query.layout.stride) + k_tensor = empty_strided(key.layout.size, key.layout.stride) + v_tensor = empty_strided(value.layout.size, value.layout.stride) + out_tensor = empty_strided(out.layout.size, out.layout.stride) + + # Flatten batch and head dimensions (n, h) into a single dimension (b = n*h) + q_tensor = q_tensor.view([-1, q_tensor.shape[-2], q_tensor.shape[-1]]) + k_tensor = k_tensor.view([-1, k_tensor.shape[-2], k_tensor.shape[-1]]) + v_tensor = v_tensor.view([-1, v_tensor.shape[-2], v_tensor.shape[-1]]) + out_tensor = out_tensor.view([-1, out_tensor.shape[-2], out_tensor.shape[-1]]) + + b, l, s, e, ev = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1), k_tensor.size(2), v_tensor.size(2) + + n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + + return query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node + + # Reuse the existing function in MLIRBMMTemplate. + def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_node): + + # FIXME: Update the method for getting tile candidates once TestDmaFineGrained oass works correctly with Flash Attention. + # tile_candidates = kernel.flash_sdpa_mapping(l, s, e, n_extra_node=n_extra_node) + tile_candidates = [[kernel.vector_lane, kernel.vector_lane, e]] + + for idx, (tile_l, tile_s, tile_e) in enumerate(tile_candidates): + subtile_l = tile_l if (tile_l < kernel.vector_lane) or n_prologue_node else kernel.vector_lane + subtile_s = tile_s # if (tile_s < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + subtile_e = tile_e # if (tile_e < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + + tile_candidates[idx] = tile_l,tile_s,tile_e,subtile_l,subtile_s,subtile_e + + return tile_candidates + diff --git a/PyTorchSimFrontend/mlir/mlir_sort_template.py b/PyTorchSimFrontend/mlir/mlir_sort_template.py new file mode 100644 index 00000000..24b3a460 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sort_template.py @@ -0,0 +1,474 @@ +from typing import List, Optional +import contextlib + +from torch._inductor.ir import Buffer, IRNode +from torch._inductor.virtualized import _ops as ops +from torch._inductor.codegen import common + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel +from PyTorchSimFrontend.mlir.mlir_common import LoopLevel + +VECTOR_SIZE = 16 + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} +// chunk index -> element index +#map_chunk_to_elem = affine_map<(d0) -> (d0 * {{ VECTOR_SIZE }})> + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, XI], outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_TILE_DESC, id=0, indent_size=2) }} + {{ kernel.def_sram_buffer("XI", XI_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer("YV", YV_TILE_DESC, id=2, indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} + + + affine.for %sort_block = 0 to 1 step 1 { + {%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ STEP_SIZES[d] }} { + {%- endfor %} + + %x_dram_offset = affine.apply {{ X_OFFSET_MAP }}({{ OUTER_VARS }}) + %xi_dram_offset = affine.apply {{ XI_OFFSET_MAP }}({{ OUTER_VARS }}) + %yv_dram_offset = affine.apply {{ YV_OFFSET_MAP }}({{ OUTER_VARS }}) + {{ kernel.def_dma_op("MVIN", "X", [], X_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=X_DRAM_STRIDE, dram_offset="x_dram_offset") }} + + // SIMD local sort + loop-based chunk merge. +{{ BITONIC_BODY }} + + {{ kernel.def_dma_op("MVOUT", "XI", [], XI_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=XI_DRAM_STRIDE, dram_offset="xi_dram_offset") }} + {{ kernel.def_dma_op("MVOUT", "YV", [], YV_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=YV_DRAM_STRIDE, dram_offset="yv_dram_offset") }} + {%- for d in range(RANK-1) %} + } { outer_loop=true } + {%- endfor %} + } { outer_loop=true } + return +} +""" + + +def _make_offset_map(outer_dims, all_strides, layout_offset): + """Build an affine_map over outer-dim loop variables that computes the flat DRAM offset.""" + terms = [] + for j, d in enumerate(outer_dims): + s = int(all_strides[d]) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + nd = len(outer_dims) + dim_str = ", ".join(f"d{j}" for j in range(nd)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str}) -> ({expr})>" + + +def _compute_bitonic_stages(n: int, descending: bool): + stages = [] + size = 2 + while size <= n: + stride = size // 2 + while stride >= 1: + merged_shuffle = list(range(n)) + merged_mask = [None] * n + for start in range(0, n, size): + blk_dir = "ASCENDING" if (start // size) % 2 == 0 else "DESCENDING" + for i in range(start, start + size - stride, stride * 2): + for j2 in range(stride): + a, b = i + j2, i + j2 + stride + merged_shuffle[a] = b + merged_shuffle[b] = a + if blk_dir == "ASCENDING": + merged_mask[a] = True + merged_mask[b] = False + else: + merged_mask[a] = False + merged_mask[b] = True + select_min = [bool(x) if x is not None else False for x in merged_mask] + if descending: + select_min = [not x for x in select_min] + stages.append({"shuffle": merged_shuffle, "select_min": select_min}) + stride //= 2 + size *= 2 + return stages + + +def _pair_less_equal(left_v, right_v, left_i, right_i): + cmp_val = ops.lt(left_v, right_v) + cmp_eq = ops.eq(left_v, right_v) + cmp_idx = ops.le(left_i, right_i) + return ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + + +def _pair_greater_equal(left_v, right_v, left_i, right_i): + cmp_val = ops.gt(left_v, right_v) + cmp_eq = ops.eq(left_v, right_v) + cmp_idx = ops.le(left_i, right_i) + return ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + + +def _bitonic_sort_pair(values, indices, vector_size: int, descending: bool, stable_sort: bool): + cur_v = values + cur_i = indices + for stage_desc in _compute_bitonic_stages(vector_size, descending): + mask = ops.constant_mask(stage_desc["select_min"], vector_size) + shuf_v = ops.vector_shuffle(cur_v, stage_desc["shuffle"]) + shuf_i = ops.vector_shuffle(cur_i, stage_desc["shuffle"]) + if stable_sort: + # `cmp` drives the "min side" selection in the bitonic network. + # For descending stable sort, tie elements with smaller original index + # must stay earlier, so the min side should treat larger index as smaller. + if descending: + cmp_val = ops.lt(cur_v, shuf_v) + cmp_eq = ops.eq(cur_v, shuf_v) + cmp_idx = ops.ge(cur_i, shuf_i) + cmp = ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + else: + cmp = _pair_less_equal(cur_v, shuf_v, cur_i, shuf_i) + else: + cmp = ops.le(cur_v, shuf_v) + min_v = ops.where(cmp, cur_v, shuf_v) + min_i = ops.where(cmp, cur_i, shuf_i) + max_v = ops.where(cmp, shuf_v, cur_v) + max_i = ops.where(cmp, shuf_i, cur_i) + cur_v = ops.where(mask, min_v, max_v) + cur_i = ops.where(mask, min_i, max_i) + return cur_v, cur_i + + +def _merge_sorted_pair_vectors( + left_norm, + left_idx_norm, + right_norm, + right_idx_norm, + ascending: bool, + stable_sort: bool, + vector_size: int, + rev_indices, +): + right_pair = ops.vector_shuffle(right_norm, rev_indices, right_norm) + right_idx_pair = ops.vector_shuffle(right_idx_norm, rev_indices, right_idx_norm) + if ascending: + cmp = ( + _pair_less_equal(left_norm, right_pair, left_idx_norm, right_idx_pair) + if stable_sort + else ops.le(left_norm, right_pair) + ) + else: + cmp = ( + _pair_greater_equal(left_norm, right_pair, left_idx_norm, right_idx_pair) + if stable_sort + else ops.ge(left_norm, right_pair) + ) + left_merge = ops.where(cmp, left_norm, right_pair) + left_idx_merge = ops.where(cmp, left_idx_norm, right_idx_pair) + right_merge = ops.where(cmp, right_pair, left_norm) + right_idx_merge = ops.where(cmp, right_idx_pair, left_idx_norm) + return left_merge, left_idx_merge, right_merge, right_idx_merge + + +class MLIRSortTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=False, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.dim = dim + self.descending = descending + self.stable = stable + self.use_stable_sort = False + self.output_nodes = [ + Buffer(name="buf_out_values", layout=layout), + ] + self.output_node = self.output_nodes[0] + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + if template_buffer_node is not None: + self.output_nodes[0] = template_buffer_node + self.output_node = template_buffer_node + + x = self.input_nodes[0] + xi = self.input_nodes[1] + yv = self.output_nodes[0] + # XI is updated in-place by the sort kernel, so mark it as an inout arg. + kernel.kernel_group.args.make_inplace(xi.get_name(), xi.get_name()) + sort_size = int(x.get_size()[self.dim]) + vector_size = VECTOR_SIZE + if sort_size <= 0: + raise NotImplementedError("Sort size must be > 0") + if sort_size < vector_size or sort_size % vector_size != 0: + raise NotImplementedError( + f"Sort size must be a multiple of vector size (sort_size={sort_size}, vector_size={vector_size})" + ) + num_chunks = sort_size // vector_size + if num_chunks & (num_chunks - 1): + raise NotImplementedError( + f"Loop-based bitonic chunk merge requires power-of-two chunk count (num_chunks={num_chunks})" + ) + + # --- N-D generalization: outer loops over all non-sort dims --- + rank = len(x.get_size()) + sort_dim = self.dim if self.dim >= 0 else self.dim + rank + if sort_dim < 0 or sort_dim >= rank: + raise NotImplementedError(f"Invalid sort dim for rank-{rank} tensor (dim={self.dim})") + x_layout = x.get_layout() + xi_layout = xi.get_layout() + yv_layout = yv.get_layout() + + if rank == 1: + # Edge case for 1D tensor + output_sizes = [1] + output_dim = [0] + step_sizes = [1] + tile_sizes = [1, sort_size] + x_dram_stride = [int(x_layout.stride[sort_dim]), int(x_layout.stride[sort_dim])] + xi_dram_stride = [int(xi_layout.stride[sort_dim]), int(xi_layout.stride[sort_dim])] + yv_dram_stride = [int(yv_layout.stride[sort_dim]), int(yv_layout.stride[sort_dim])] + template_rank = 2 + else: + output_sizes = [sz for d, sz in enumerate(yv.get_size()) if d != sort_dim] + output_dim = [d for d, _ in enumerate(yv.get_size()) if d != sort_dim] + step_sizes = [1] * len(output_sizes) + + tile_dim = max(output_dim, key=lambda d: int(yv.get_size()[d])) + tile_sizes = [min(kernel.vector_lane, int(yv.get_size()[tile_dim])), sort_size] + step_sizes[output_dim.index(tile_dim)] = tile_sizes[0] + + x_dram_stride = [int(x_layout.stride[tile_dim]), int(x_layout.stride[sort_dim])] + xi_dram_stride = [int(xi_layout.stride[tile_dim]), int(xi_layout.stride[sort_dim])] + yv_dram_stride = [int(yv_layout.stride[tile_dim]), int(yv_layout.stride[sort_dim])] + template_rank = rank + + x_offset_map = _make_offset_map(output_dim, x_layout.stride, x_layout.offset) + xi_offset_map = _make_offset_map(output_dim, xi_layout.stride, xi_layout.offset) + yv_offset_map = _make_offset_map(output_dim, yv_layout.stride, yv_layout.offset) + outer_vars = ", ".join(f"%index{d}" for d in output_dim) + + # indent for DMA ops = 2 (inside func) + 2 per outer loop + indent_size = 2 + len(output_dim) * 2 + 4 + + vlane_stride = 1 + vlane_split_axis = 0 + x_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + x_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + x_tile_desc.set_name("X_buffer") + x_tile_desc.offset = x_layout.offset + + xi_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + xi_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + xi_tile_desc.set_name("XI_buffer") + xi_tile_desc.offset = xi_layout.offset + + yv_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + yv_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + yv_tile_desc.set_name("YV_buffer") + yv_tile_desc.offset = yv_layout.offset + + data_stype = mlir_common.DTYPE_TO_MLIR[x.get_dtype()] + idx_stype = mlir_common.DTYPE_TO_MLIR[xi.get_dtype()] + + elem_memref_t = f"memref<1x{sort_size}x{data_stype}, 1>" + rev_indices = list(range(vector_size - 1, -1, -1)) + + bitonic_body = mlir_common.ParallelLoopBuffer(initial_indent=2) + bitonic_body.tabwidth = 2 + # 1) Local SIMD sort per chunk. + init_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix="sort_init") + with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=init_cse): + bitonic_body.writelines(LoopLevel("chunk", num_chunks).lines()) + with bitonic_body.indent(attribute="{inner_loop=true}"): + bitonic_body.writeline("%elem = affine.apply #map_chunk_to_elem(%chunk)") + x_chunk = ops._load( + vector_size, + data_stype, + "X_buffer", + "%t_const0, %elem", + x_tile_desc.get_mlir_shape(data_stype), + ) + idx_step_index = kernel.register_var_cse("idx_step_index", vector_size, "index") + bitonic_body.writeline(f"%{idx_step_index} = vector.step : vector<{vector_size}xindex>") + idx_step = ops.index_cast(idx_step_index, idx_stype) + idx_base = kernel.register_var_cse("idx_base", 1, idx_stype) + bitonic_body.writeline(f"%{idx_base} = arith.index_cast %elem : index to {idx_stype}") + idx_base_vec = ops.broadcast(idx_base, vector_size) + idx_chunk = ops.add(idx_base_vec, idx_step) + yv_chunk, yi_chunk = _bitonic_sort_pair( + x_chunk, idx_chunk, vector_size, descending=self.descending, stable_sort=self.use_stable_sort + ) + ops._store( + yv_chunk, + "YV_buffer", + "%t_const0, %elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + yi_chunk, + "XI_buffer", + "%t_const0, %elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + + # 2) Chunk-level bitonic merge (loop form). + stage = 0 + k = 2 + while k <= num_chunks: + j = k // 2 + while j >= 1: + for block_start, is_even_block in ((0, True), (k, False)): + if block_start >= num_chunks: + continue + asc_dir = is_even_block if not self.descending else (not is_even_block) + stage_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix=f"sort_stage_{stage}") + with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=stage_cse): + stage_loops = [ + LoopLevel("base", num_chunks, start=block_start, step=2 * k), + LoopLevel("p", k, step=2 * j), + LoopLevel("q", j), + ] + with contextlib.ExitStack() as stack: + for loop in stage_loops: + bitonic_body.writelines(loop.lines()) + stack.enter_context(bitonic_body.indent(attribute="{inner_loop=true}")) + + bitonic_body.writeline( + f"%left_elem = affine.apply affine_map<(d0, d1, d2) -> ((d0 + d1 + d2) * {vector_size})>(%base, %p, %q)" + ) + bitonic_body.writeline( + f"%right_elem = affine.apply affine_map<(d0, d1, d2) -> ((d0 + d1 + d2 + {j}) * {vector_size})>(%base, %p, %q)" + ) + + left_vec = ops._load( + vector_size, + data_stype, + "YV_buffer", + "%t_const0, %left_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + right_vec = ops._load( + vector_size, + data_stype, + "YV_buffer", + "%t_const0, %right_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + left_idx = ops._load( + vector_size, + idx_stype, + "XI_buffer", + "%t_const0, %left_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + right_idx = ops._load( + vector_size, + idx_stype, + "XI_buffer", + "%t_const0, %right_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + norm_desc = not asc_dir + left_norm, left_idx_norm = _bitonic_sort_pair( + left_vec, left_idx, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + right_norm, right_idx_norm = _bitonic_sort_pair( + right_vec, right_idx, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + left_merge, left_idx_merge, right_merge, right_idx_merge = _merge_sorted_pair_vectors( + left_norm, + left_idx_norm, + right_norm, + right_idx_norm, + ascending=asc_dir, + stable_sort=self.use_stable_sort, + vector_size=vector_size, + rev_indices=rev_indices, + ) + left_new, left_idx_new = _bitonic_sort_pair( + left_merge, left_idx_merge, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + right_new, right_idx_new = _bitonic_sort_pair( + right_merge, right_idx_merge, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + ops._store( + left_new, + "YV_buffer", + "%t_const0, %left_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + right_new, + "YV_buffer", + "%t_const0, %right_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + left_idx_new, + "XI_buffer", + "%t_const0, %left_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + ops._store( + right_idx_new, + "XI_buffer", + "%t_const0, %right_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + stage += 1 + j //= 2 + k *= 2 + + kernel.render_options = dict( + KERNEL_NAME=self.name, + NAMES_STR="X, XI, YV", + kernel=kernel, + X=x, + XI=xi, + YV=yv, + X_TILE_DESC=x_tile_desc, + XI_TILE_DESC=xi_tile_desc, + YV_TILE_DESC=yv_tile_desc, + SORT_SIZE=sort_size, + VECTOR_SIZE=vector_size, + DATA_STYPE=data_stype, + IDX_STYPE=idx_stype, + ELEM_MEMREF_T=elem_memref_t, + BITONIC_BODY=bitonic_body.getvalue().rstrip(), + input_reorder=self.input_reorder, + # N-D generalization + RANK = template_rank, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + STEP_SIZES = step_sizes, + OUTER_VARS = outer_vars, + X_OFFSET_MAP = x_offset_map, + XI_OFFSET_MAP = xi_offset_map, + YV_OFFSET_MAP = yv_offset_map, + X_DRAM_STRIDE = x_dram_stride, + XI_DRAM_STRIDE = xi_dram_stride, + YV_DRAM_STRIDE = yv_dram_stride, + INDENT_SIZE = indent_size, + ) + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code + + +class MLIRStableSortTemplate(MLIRSortTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=True, input_reorder=None): + super().__init__( + input_nodes=input_nodes, + layout=layout, + dim=dim, + descending=descending, + stable=stable, + input_reorder=input_reorder, + ) + self.use_stable_sort = True diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index e493464a..c8fc036f 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -13,8 +13,9 @@ from typing import List, Optional from unittest.mock import patch -from torch._inductor.codegen.common import KernelTemplate, ChoiceCaller, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer +from PyTorchSimFrontend import extension_config +from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller, ir_node_to_tensor from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -25,13 +26,15 @@ import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo -from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, is_welford_reduction from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode from torch._inductor.codegen import common -from PyTorchSimFrontend import extension_config from . import mlir_common +# Configure logger for mlir_template module +logger = extension_config.setup_logger() + class IndentedBufferGroup: def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): self.kernel = kernel @@ -85,7 +88,8 @@ def as_local(self): } try: self.set_buffers() - yield self + with self.kernel.override_buffer_cse(buffer=self.compute, cse=self.cse): + yield self finally: self.restore_buffers() @@ -108,7 +112,8 @@ def __init__(self, self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes - self.render_hooks = OrderedDict() + self.render_hooks = OrderedDict() # Stores {key: (priority, hook)} + self.dma_op_counter = itertools.count() # Add counter for unique DMA op keys self.buffer_names = dict() self.render_options = dict() self.tile_size = [] @@ -120,6 +125,7 @@ def __init__(self, self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") self.global_vars = IndentedBuffer() self.exception_nodes = {} + self.epilogue_info = {} # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False @@ -144,10 +150,10 @@ def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): self.loop_info[f"index{idx}"] = [0, loop_size, stride] - def gemmini_gemm_mapping(self, M, N, K): + def gemmini_gemm_mapping(self, M, N, K, precision_bytes=4): spad_size = self.spad_info["spad_size"] * self.vector_lane num_cores = self.num_cores - precision = self.precision + precision = precision_bytes dim_I, dim_J, dim_K = M, N, K dim = self.vector_lane @@ -199,7 +205,7 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False): + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -227,11 +233,11 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: dir_path = f"{extension_config.CONFIG_TORCHSIM_DIR}/validation/gemm_candidates" @@ -253,11 +259,11 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes n_tile = math.ceil(M / max(tile_M, 128)) * math.ceil(N / max(tile_N, 128)) check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and max(tile_N, 128) // max(tile_M, 128) < 10: @@ -271,7 +277,7 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -279,7 +285,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation max_spad_per_lane = spad_size_per_lane // 2 # double buffer max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = 1 # maximize kernel size max_o_h_w = 1 # maximize output size K = min(K, self.vector_lane) @@ -292,11 +298,11 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation weight_size = k_w * k_h * K * N input_size = i_w * i_h * M * K output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, k_w, o_h, o_w, M, N, K))) @@ -312,7 +318,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -320,7 +326,7 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = K_W for o_h in sympy.divisors(O_H): for o_w in sympy.divisors(O_W): @@ -330,11 +336,11 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, weight_size = 1 * k_h * K * N input_size = i_w * i_h * M * K output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(1 * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, K_W, o_h, o_w, M, N, K))) @@ -348,7 +354,7 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -356,7 +362,7 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = 1 for o_h in sympy.divisors(O_H): for k_h in sympy.divisors(K_H): @@ -366,11 +372,11 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio weight_size = k_w * k_h * K * N input_size = i_w * i_h * k_w * K output_size = M * o_h * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * k_w, K) output_size_per_lane = self.get_spad_size_per_lane(M * o_h * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, k_w, o_h, M, M, N, K))) @@ -383,9 +389,102 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) tile_candidates = [v for _, v in tile_candidates] return tile_candidates + + # Flash Attention requires more SRAM compared to standard GEMM. + # Total buffers needed: query, key, value, out, mul, max, sum + # Tensor Shapes: + # query (tile_l, tile_e), key (tile_s, tile_e), value (tile_s, tile_e), mul (tile_s, tile_l), out(tile_l, tile_e) + # max, sum : (tile_l, 2) + def flash_sdpa_mapping(self, l, s, e, n_extra_node=0, n_prologue_node=0, pad_e=True, min_tile=False, is_conv=False): + tile_candidates = [] + + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + + # Double buffering + max_spad_per_lane = spad_size_per_lane // 2 + max_spad_size = spad_size // 2 + + # Padding for utilization + minimum_tile_size = 8 + minimum_n_tile = self.num_cores if min_tile else 1 + l_pad_factor = self.vector_lane if l > self.vector_lane else minimum_tile_size + s_pad_factor = self.vector_lane if s > self.vector_lane else minimum_tile_size + + pad = lambda x, factor: ((x + factor - 1) // factor) * factor + l_padded = pad(l, l_pad_factor) + s_padded = pad(s, s_pad_factor) + + # Calculate the total number of vector-sized blocks + l_idx = l_padded // self.vector_lane + s_idx = s_padded // self.vector_lane + + # Generate candidates for the number of blocks per tile + l_tile_range = sympy.divisors(l_idx) if l > self.vector_lane else [1] + s_tile_range = sympy.divisors(s_idx) if s > self.vector_lane else [1] + + # Convert block count to actual tile size + maximize_i_j = 1 + max_used_spad_size = 0 + + # Flash Attention does not tile along the head dimension (e or ev). + tile_e = e + + for i in l_tile_range: + tile_l = i * self.vector_lane if l > self.vector_lane else l_padded + for j in s_tile_range: + tile_s = j * self.vector_lane if s > self.vector_lane else s_padded + + # Calculate used spad size + used_spad_size = ( + tile_l * tile_e * (1 + n_prologue_node) # query + + tile_s * tile_e # key + + tile_s * tile_e # value + + tile_s * tile_l # mul + + tile_l * tile_e * (1 + n_extra_node) # out + + (tile_l * 2) * 2 # max, sum + ) * self.precision + + # Calculate used spad size per lane. + query_per_lane = tile_e * (1+n_prologue_node) + key_per_lane = tile_s + value_per_lane = tile_e + mul_per_lane = tile_s + out_per_lane = tile_e * (1 + n_extra_node) + vec_per_lane = 2 * 2 + + used_spad_per_lane = ( + query_per_lane + + key_per_lane + + value_per_lane + + mul_per_lane + + out_per_lane + + vec_per_lane + ) * self.precision + + # Add the validated candidate to the list if it passes all hardware constraints. + n_tile = math.ceil(l / max(tile_l, 128)) * math.ceil(s / max(tile_s, 128)) + check_spad_size = (used_spad_size < max_spad_size and used_spad_per_lane < max_spad_per_lane) + + if (check_spad_size + and max_used_spad_size < used_spad_size # SRAM utilization + and maximize_i_j <= tile_l * tile_s # Larger tile + and n_tile >= minimum_n_tile # Pallelism + and max(tile_s, 128) // max(tile_l, 128) < 10): # Balanced Shape + max_used_spad_size = used_spad_size + maximize_i_j = tile_l * tile_s + + if check_spad_size: + tile_candidates.append((used_spad_size, (tile_l, tile_s, tile_e))) + + # Sort by used_spad_size. + # tile_candidates[0] is the best solution we have. + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + + return tile_candidates def meta_kernel(self): - wrapper = V.graph.wrapper_code kernel_arg_attributes = self.kernel_arg_attributes _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() if kernel_arg_attributes is not None: @@ -393,18 +492,14 @@ def meta_kernel(self): for idx in range(len(arg_attributes)): if arg_attributes[idx][0] == name: arg_attributes[idx][1] = attr - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"loop_info = {self.loop_info}") - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") + return arg_attributes def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", - call_args, cuda=False) + kernel_name if self.outer_func_name is None else "wrapper_" + kernel_name, call_args) def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: @@ -430,7 +525,7 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ ).group prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) kernel.kernel_group.set_tile_info(prologue_tile_desc) - vars, reduction_vars = kernel.set_ranges(group, reduction_group) + vars, reduction_vars = kernel.set_ranges(group, reduction_group, list(self.dim_aliasing.values())) for node in prologue_nodes: # Reuse created spad read_list = sorted([i.name for i in node.read_writes.reads]) @@ -461,24 +556,24 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ } node.codegen((vars, reduction_vars)) - # Codegen epilogue nodes - tile_desc = kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) - kernel.call_ranges = None if epilogue_nodes: + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None with kernel.epilogue_buffer_group.as_local(): _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) ).group - vars, reduction_vars = kernel.set_ranges(group, reduction_group) + vars, reduction_vars = kernel.set_ranges(group, reduction_group, list(self.dim_aliasing.values())) for node in epilogue_nodes: node.codegen((vars, reduction_vars)) - with V.set_kernel_handler(kernel): + with self as kernel: src_code = ( partial_code if isinstance(partial_code, str) - else partial_code.finalize() + else partial_code.finalize_all() ) # For consistency, white space could make wrong write_path @@ -486,52 +581,54 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ buffer.splice(src_code) src_code = buffer.getvalue() self._prepare_simulator_headers(src_code) - return src_code + meta_code = self.meta_kernel() + return src_code, meta_code def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): choices = [] for tile_info in tile_candidates: - if extension_config.CONFIG_DEBUG_MODE: - # Compute Tile M, N, K DMA Tile M, N, K - print(f"[Auto-tune] Trying tile size: {list(tile_info)}") - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) + # Compute Tile M, N, K DMA Tile M, N, K + logger.debug(f"Auto-tune: Trying tile size: {list(tile_info)}") + src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) bench_runner = self.run_bench([template_node], self.kernel_name, src_code) - choices.append((bench_runner, src_code, tile_info, self.loop_size)) + choices.append((bench_runner, src_code, meta_code, tile_info, self.loop_size)) self.reset(reason=None) return choices def _log_autotune_result(self, best_choice, best_cycle): - tile_size = best_choice[2] - print( - f"[Auto-tune] Optimal tile size: {list(tile_size)}, " + tile_size = best_choice[3] + logger.debug( + f"Auto-tune: Optimal tile size: {list(tile_size)}, " f"cycles: {best_cycle}" ) def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): if "autotune" in extension_config.codegen_mapping_strategy and len(tile_candidates): - src_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) + src_code, meta_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) self.loop_size = loop_size else: tile_info = tile_candidates[0] if tile_candidates else None - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) + src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) - with V.set_kernel_handler(self): - self.meta_kernel() - return src_code + return src_code, meta_code def _prepare_simulator_headers(self, src_code): + from filelock import FileLock + spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({self.spad_info['spad_size']*self.vector_lane})));" write_path = extension_codecache.get_write_path(src_code) - if not os.path.exists(write_path): - os.makedirs(write_path, exist_ok=True) + os.makedirs(write_path, exist_ok=True) spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header.getvalue()+spad_end_symbol+spad_section_end_symbol) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header.getvalue()) + + lock = FileLock(extension_codecache.get_lock_path(write_path), timeout=extension_codecache.LOCK_TIMEOUT) + with lock: + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, self.header.getvalue()+spad_end_symbol+spad_section_end_symbol) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, self.gem5_header.getvalue()) def codegen_prologue_body(self): body = IndentedBuffer() @@ -557,7 +654,7 @@ def template_store(): dram_var = self.epilogue_info["dram_var"] index_list = self.epilogue_info["dram_idx"] tile_desc = self.epilogue_info["dram_tile_desc"] - code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc, lazy_mode=False) self.cse.generate(self.dma_stores, code, assignment = False) body = IndentedBuffer() @@ -576,8 +673,8 @@ def template_store(): with contextlib.ExitStack() as stack: stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) if self.reduction_fusion: - compute_body.writelines(self.reduction_body_loop.lines()) compute_body.splice(self.masks) + compute_body.writelines(self.reduction_body_loop.lines()) stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) compute_body.splice(self.loads) compute_body.splice(self.compute) @@ -628,14 +725,34 @@ def def_kernel( extra_node[node.get_name()] = node.node else: extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] + + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) - return f"({', '.join(arg_defs)})" + arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) + output_names = names[len(inputs) : len(inputs) + len(outputs)] + out_ptr_idx = 0 + renamed_arg_defs = [] + for outer, arg_def in zip(call_args, arg_defs): + raw_symbol = arg_def.split(":", 1)[0].strip().lstrip("%") + if outer in self.kernel_group.args.input_buffers: + symbol = self.kernel_group.args.input_buffers[outer] + elif outer in self.kernel_group.args.output_buffers: + symbol = self.kernel_group.args.output_buffers[outer] + elif raw_symbol.startswith("out_ptr") and out_ptr_idx < len(output_names): + symbol = output_names[out_ptr_idx] + out_ptr_idx += 1 + elif outer in self.kernel_group.args.sizevars: + symbol = self.kernel_group.args.sizevars[outer] + else: + symbol = raw_symbol + _, arg_type = arg_def.split(":", 1) + renamed_arg_defs.append(f"%{symbol}:{arg_type}") + return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (5, hook) # Default priority 5 return "" # This function is a temporal function for convolution because currently convolution kernel is not considering padding. @@ -673,7 +790,8 @@ def def_conv_kernel( self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling store() in mlir_common.py? self.extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed def kernel_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) @@ -681,7 +799,7 @@ def kernel_hook(): return f"({', '.join(arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = kernel_hook + self.render_hooks[""] = (5, kernel_hook) # Default priority 5 return "" # This function is for convolution wrapper function finalizing. @@ -692,7 +810,7 @@ def wrapper_hook(): return f"({', '.join(wrapper_arg_defs)})" if "" not in self.render_hooks: - self.render_hooks[""] = wrapper_hook + self.render_hooks[""] = (5, wrapper_hook) # Default priority 5 return "" def get_conv_inputs(self): @@ -701,15 +819,15 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} - def load_input(self, indent_size: int = 0): + def load_input(self, indent_size: int = 0, priority: int = 1): def hook(): code = IndentedBuffer() prologue_code = self.codegen_prologue_body() if prologue_code.getvalue(): input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) if (self.prologue_info["is_input_fused"]): code.splice(input_dma_code) code.splice(prologue_code) @@ -720,58 +838,63 @@ def hook(): code.splice(input_dma_code) else: dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) code = textwrap.indent(code.getvalue(), " "*indent_size).strip() return code assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def store_output(self, indent_size: int = 0): + def store_output(self, indent_size: int = 0, priority: int = 1): def hook(): epilogue_code = self.codegen_epilogue_body() return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def reduction_output(self, indent_size: int = 0): + def reduction_output(self, indent_size: int = 0, priority: int = 5): def hook(): return textwrap.indent(self.reductions_suffix.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (priority, hook) return "" + def _sort_hooks_by_priority(self): + """Sort hooks by priority (lower priority executes first).""" + sorted_hooks = OrderedDict() + for key, (priority, hook) in sorted(self.render_hooks.items(), key=lambda x: x[1][0]): + sorted_hooks[key] = hook + return sorted_hooks + def def_function(self): - _, call_args, _ = self.kernel_group.args.python_argdefs() + _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) + return PartialRender( partial_code, - self.render_hooks, + self._sort_hooks_by_priority(), ), function_name else: return None, None - def def_global_vars(self): + def def_global_vars(self, priority: int = 10): key = "" def hook(): return textwrap.indent(self.global_vars.getvalue(), "").strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key - def def_local_vars(self, indent_size=0): + def def_local_vars(self, indent_size=0, priority: int = 10): key = "" def hook(): code = IndentedBuffer() @@ -780,57 +903,89 @@ def hook(): code.splice(self.alloc_buffer) return textwrap.indent(code.getvalue(), " "*indent_size).strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, - subtile_size:list=[], async_type=None, indent_size=0): - # Prepare code block - local_code = IndentedBuffer() - with V.set_kernel_handler(self): - index_var = self.parse_index_list(index_list, local_code, offset=tile_desc.offset) - node_layout = self.named_nodes[dram_var].get_layout() - if dram_var in self.exception_nodes: - numel = self.exception_nodes[dram_var]["numel"] - else: - numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() - mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] - dram_shape = f"memref<{numel}x{mlir_dtype}>" - dram_stride = [] - for idx in index_list: - if idx.is_Mul: - dram_stride.append(int(idx.args[0])) - elif idx == sympy.Symbol("c0"): - dram_stride.append(0) - elif not idx.is_Number: - dram_stride.append(1) + subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True, + dram_stride:list=None, dram_offset=None, padding: int = 0): + # Todo. Remove legacy behavior (i.e., index_list parsing) + def generate_dma_code(): + """Internal method to generate DMA code directly.""" + local_code = IndentedBuffer() + with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): + if dram_offset is not None: + # Use explicitly provided offset (pre-computed MLIR SSA variable name) + index_var = dram_offset + else: + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) + node_layout = self.named_nodes[dram_var].get_layout() + if dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] else: - dram_stride.append(0) - - sram_var = tile_desc.get_name() - tile_shape = tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = tile_desc.get_tile_stride() - vlane_split_axis = tile_desc.vmap.vlane_split_axis - vlane_stride = tile_desc.vmap.vlane_stride - - zero_cse = self.get_const_cse(0, "index") - sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - - attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] - if subtile_size: - attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") - attribute = " {" + ", ".join(attribute_parts) + "}" - code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, "") - local_code.writeline(code) - local_code.writeline(attribute) - return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + + if dram_stride is not None: + # Use explicitly provided dram_stride + _dram_stride = dram_stride + else: + # Extract dram_stride from index_list (legacy behavior) + _dram_stride = [] + for idx in index_list: + if idx.is_Mul: + _dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + _dram_stride.append(0) + elif not idx.is_Number: + _dram_stride.append(1) + else: + _dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + sram_strides = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vmap.vlane_split_axis + vlane_stride = tile_desc.vmap.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + if subtile_size: + attribute = mlir_common.format_dma_op_attributes( + _dram_stride, + sram_strides, + int(padding), + subtile_size=subtile_size, + async_type=int(async_type) if async_type is not None else None, + ) + else: + attribute = mlir_common.format_dma_op_attributes(_dram_stride, sram_strides, int(padding)) + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + local_code.writeline(code) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + if not lazy_mode: + # Immediate mode: generate code directly and return it + return generate_dma_code() + + # Lazy mode: register hook and return key + dma_op_id = next(self.dma_op_counter) + key = f"" + self.render_hooks[key] = (priority, generate_dma_code) + return key def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block - with V.set_kernel_handler(self): - dtype = self.named_nodes[dram_name].get_layout().dtype + with self: + try: + dtype = self.named_nodes[dram_name].get_layout().dtype + except (KeyError, AttributeError, TypeError): + import torch + dtype = torch.float32 + tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) code = f"%{tile_desc.name} = memref.get_global @{buffer_name} : {tile_shape}" @@ -843,7 +998,7 @@ def render(self, template, kwargs, define_function=None): return PartialRender( code, - self.render_hooks, + self._sort_hooks_by_priority(), ) def get_spad_size_per_lane(self, tile_m, tile_n): @@ -851,19 +1006,21 @@ def get_spad_size_per_lane(self, tile_m, tile_n): return max(size, 2) # vector load/store def load_epilogue(self, name: str, index: sympy.Expr): - index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] # Want to use tile_desc from epilogue_info - index_var = self.parse_indices(index) + with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse): + index_var = self.parse_indices(index) dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) tile_stride = self.kernel_group.tile_desc.get_tile_stride() + tile_rank = self.kernel_group.tile_desc.get_nr_dim() + dram_stride = dram_stride[:tile_rank] + [0] * max(tile_rank - len(dram_stride), 0) # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) @@ -873,7 +1030,7 @@ def load_epilogue(self, name: str, index: sympy.Expr): # Allocate sram buffer dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) - attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, dram_shape, tile_shape, attribute) self.cse.generate(self.dma_loads, code, assignment = False) @@ -885,48 +1042,39 @@ def load_epilogue(self, name: str, index: sympy.Expr): zero_var = self.get_const_cse(0) if not self.reduction_fusion: compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) - if compute_vec_size > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - out = self.cse.generate(self.loads, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) + with self.override_buffer_cse(buffer=self.loads): + out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape) else: # For reduction case reduce_size = self.reduction_nr_outer_loop vsize = compute_vec_size//reduce_size - vshape = f"vector<{vsize}x{mlir_dtype}>" if compute_vec_size > 1: - offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(["d0", "d1"], f"d0 + d1*{(self.r_tile_size)}") + with self.override_buffer_cse(buffer=self.loads): + offset = ops.affine_apply(map_var, [self.compute_idx, self.reduction_loop_idx]) compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - out = self.cse.generate(self.loads, line) - else: - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - out = self.cse.generate(self.loads, line) + + with self.override_buffer_cse(buffer=self.loads): + out = ops._load(vsize, mlir_dtype, sram_var, compute_index_var, tile_shape) self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype]) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - index_var = self.parse_indices(index) + with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse): + index_var = self.parse_indices(index) dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) tile_stride = self.kernel_group.tile_desc.get_tile_stride() - - # Compute vector unit size - vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) - compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() + tile_rank = self.kernel_group.tile_desc.get_nr_dim() + dram_stride = dram_stride[:tile_rank] + [0] * max(tile_rank - len(dram_stride), 0) if name not in self.buffer_names: sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) @@ -942,20 +1090,15 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): _, operand_type = self.var_info[value] if mlir_dtype != operand_type: - value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) + value = ops.to_dtype(value, mlir_dtype) compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) # Generate vector load instruction - if compute_vec_size > 1: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" - line = line if store_force else DeferredLine(name, line) - self.stores.writeline(line) + buffer_name = name if not store_force else None + with self.override_buffer_cse(buffer=self.stores): + ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=buffer_name) # Generate DMA instruction - attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, dram_shape, tile_shape, attribute) self.dma_stores.writeline(DeferredLine(name, code)) @@ -991,6 +1134,7 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): tile_shape = local_tile_desc.get_mlir_shape(type_name) vshape = local_tile_desc.get_mlir_vshape(type_name) + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() name = f"{reduction_type}_buffer{self.reduction_buffer_idx}" self.reduction_buffer_idx += 1 @@ -1002,35 +1146,34 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): zero_var_list = [f"%{self.get_const_cse(0)}"] * local_tile_desc.get_nr_dim() zero_var_list[-2] = f"%{self.reduction_loop_idx}" compute_index_var = ", ".join(zero_var_list) - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - out = self.cse.generate(self.loads, line) - self.register_var_info(out, [self.compute_body_loop.step, type_name]) - + with self.override_buffer_cse(buffer=self.loads): + out = ops._load(vec_size, type_name, sram_var, compute_index_var, tile_shape) # Reduction body codegen - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {vshape}") - self.register_var_info(init_vec, [local_tile_desc.get_compute_vec_size(), type_name]) + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + init = ops.constant(reduction_init(reduction_type, dtype), type_name) + init_vec = ops.broadcast(init, compute_vec_size) + init_vec2 = ops.broadcast(init, local_tile_desc.get_numel_per_lane()) + ops._store(init_vec2, sram_var, ", ".join([f"%{self.get_const_cse(0)}"] * local_tile_desc.get_nr_dim()), tile_shape) + mask_shape, mask_var = self.get_mask() if mask_var is not None: value = ops.where(mask_var, value, init_vec) + result = reduction_partial_combine_vec(reduction_type, value, out) # Store partial result - operation = "affine.vector_store" - line = f"{operation} %{result}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - self.compute.writeline(line) # Need to be placed after partial reduction + ops._store(result, sram_var, compute_index_var, tile_shape) # Need to be placed after partial reduction self.reduction_info[sram_var] = [reduction_type, local_tile_desc] return sram_var def store_reduction_epilogue(self, name, index, value): - index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - index_var = self.parse_indices(index, self.reductions_suffix, comments="// Store reduction") + with self.override_buffer_cse(buffer=self.reductions_suffix, cse=self.apply_cse): + index_var = self.parse_indices(index, comments="// Store reduction") dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()][:-1] # Assume that there is only one reduction axis vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride @@ -1050,67 +1193,63 @@ def store_reduction_epilogue(self, name, index, value): partial_tile_shape = partial_tile_desc.get_mlir_shape(mlir_dtype) # Prepare constant - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value][0], dtype)} : {mlir_dtype}") + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + init = ops.constant(reduction_init(self.reduction_info[value][0], dtype), mlir_dtype) + init_vec = ops.broadcast(init, partial_vec_size) + init_vec2 = ops.broadcast(init, 2) + partial_zero_var_list = [f"%{self.get_const_cse(0)}"] * partial_tile_desc.get_nr_dim() final_zero_var_list = [f"%{self.get_const_cse(0)}"] * final_tile_desc.get_nr_dim() for i in range(self.reduction_body_loop.size): # Load partial result - body_index_var = self.const_cse.generate(self.const_buffer, f"arith.constant {i} : index") - partial_zero_var_list[-2] = f"%{body_index_var}" - compute_index_var = ",".join(partial_zero_var_list) - - operation = "affine.vector_load" - line = f"{operation} %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" - out = self.cse.generate(self.reductions_suffix, line) - operation = "affine.vector_store" - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {partial_vshape}") - line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" - self.reductions_suffix.writeline(line) - - # 2 step reduction - new_vec_size = 2 - new_vshape = f"vector<{partial_vec_size//new_vec_size}x{new_vec_size}x{mlir_dtype}>" - new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" - out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {partial_vshape} to {new_vshape}") - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {new_reduced_shape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value][0], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) - out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + body_index_var = ops.constant(i, "index") + partial_zero_var_list[-2] = f"%{body_index_var}" + compute_index_var = ",".join(partial_zero_var_list) - self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - self.register_var_info(out, [new_vec_size, mlir_dtype]) + with self.override_buffer_cse(buffer=self.reductions_suffix): + out = ops._load(partial_vec_size, mlir_dtype, value, compute_index_var, partial_tile_shape) + ops._store(init_vec, value, compute_index_var, partial_tile_shape) # Clear the partial buffer to zero + + # 2 step reduction + new_vec_size = 2 + new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" + reduction_type = self.reduction_info[value][0] + out = ops.multi_reduction(out, init_vec2, partial_vec_size, new_vec_size, partial_vshape, reduction_type, mlir_dtype) + + out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") self.register_var_info(out2, [new_vec_size, mlir_dtype]) - out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) - self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - - if self.welford_reduce_out is not None: - # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2 - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.r_dim_size)} : f32") - if self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") - else: - divider_vec = divider - if self.current_node.node.origin_node: # FIXME: This is a temporary solution - # mean = SUM(X) / N - self.reduction_mean.append(self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}")) - out = self.reduction_mean[i] - else: - # m2 = (E(X^2) - E(X)^2) * N - sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") - mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") - variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") - out = m2 - - final_zero_var_list[-1] = f"%{body_index_var}" - final_compute_index_var = ",".join(final_zero_var_list) - operation = "affine.vector_store" - line = f"{operation} %{out}, %{sram_var}[{final_compute_index_var}] : {final_tile_shape}, {new_reduced_shape}" - self.reductions_suffix.writeline(DeferredLine(name, line)) + with self.override_buffer_cse(buffer=self.reductions_suffix): + out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) + + if self.welford_reduce_out is not None: + # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2 + divider = ops.constant(float(self.r_dim_size), "f32") + if self.buffer_types[name][1] > 1: + divider_vec = ops.broadcast(divider, new_vec_size) + else: + divider_vec = divider + + if self.current_node.node.origin_node: # FIXME: This is a temporary solution + # mean = SUM(X) / N + self.reduction_mean.append(ops.truediv(out, divider_vec)) + out = self.reduction_mean[i] + else: + # m2 = (E(X^2) - E(X)^2) * N + sqr_mean = ops.truediv(out, divider_vec) + mean_sqr = ops.mul(self.reduction_mean[i], self.reduction_mean[i]) + variance = ops.sub(sqr_mean, mean_sqr) + m2 = ops.mul(variance, divider_vec) + out = m2 + + final_zero_var_list[-1] = f"%{body_index_var}" + final_compute_index_var = ",".join(final_zero_var_list) + ops._store(out, sram_var, final_compute_index_var, final_tile_shape, buffer_name=name) # MVOUT Encoding # Generate DMA instruction - attribute = f"{{dram_stride={dram_stride}, sram_stride={final_tile_stride}, padding=0}}" + attribute = mlir_common.format_dma_op_attributes(dram_stride, final_tile_stride, 0) code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, dram_shape, final_tile_shape, attribute) self.reductions_suffix.writeline(DeferredLine(name, code)) @@ -1125,18 +1264,18 @@ def set_tile_size(self, template_fusion_info, prologue=False): numel_per_lane = tile_desc.get_numel_per_lane() r_tile_size = tile_desc.get_tile_size()[-1] nr_outer_loop = (numel_per_lane + r_tile_size-1) // r_tile_size - tile_desc.vmap.forced_vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... + tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(nr_outer_loop * 32) # Why? Emprically selected, other option failed to functionality... self.reduction_fusion = True self.r_tile_size = tile_desc.get_tile_size()[-1] self.r_dim_size = template_fusion_info['r_dim_size'] self.reduction_nr_outer_loop = nr_outer_loop - self.reduction_loop_idx = "reduce_loop_idx" + self.reduction_loop_idx = self.register_var_cse("reduce_loop_idx", 1, "index") self.compute_body_loop.size = r_tile_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: - tile_desc.vmap.forced_vec_size = 64 + tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(64) if prologue: self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() @@ -1146,15 +1285,16 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.compute_body_loop.step = tile_desc.get_compute_vec_size() return tile_desc - def rename_indexing(self, index) -> sympy.Expr: - for dim_name, dim_aliased_name in self.dim_aliasing.items(): - index = index.subs(sympy.Symbol(dim_name), sympy.Symbol("tmp_"+dim_aliased_name)) - # To avoid this case ({"index0":"index1", "index1":"index0"}) - for dim_aliased_name in self.dim_aliasing.values(): - index = index.subs(sympy.Symbol("tmp_"+dim_aliased_name), sympy.Symbol(dim_aliased_name)) - return index - class MLIRTemplateCaller(CUDATemplateCaller): + def __init__(self, name, category, input_nodes, layout, make_kernel_render, supports_epilogue_fusion, template, info_kwargs, description): + bmreq = MLIRBenchmarkRequest( + kernel_name=name, + input_tensor_meta=list(), + output_tensor_meta=list(), + extra_args=[], + source_code="", + ) + super().__init__(name, category, input_nodes, layout, make_kernel_render, bmreq, supports_epilogue_fusion, template, info_kwargs, description) def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" @@ -1177,9 +1317,15 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): """ super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] - self.output_node: Buffer = Buffer("buf_out", layout) + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Multi-output templates can override this with explicit output buffers. + self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout + # Fusion support flags (default to False) + self.support_epilogue_fusion = False + self.support_prologue_fusion = False + self.support_reduction_fusion = False def generate(self, **kwargs) -> ChoiceCaller: kernel_name = f"mlir_{self.name}" @@ -1191,15 +1337,8 @@ def generate(self, **kwargs) -> ChoiceCaller: code = self.render(kernel=kernel, **kwargs) kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" - extra_args = [] # create the BenchmarkRequest - bmreq = MLIRBenchmarkRequest( - kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), - extra_args=extra_args, - source_code=code, - ) + output_nodes = getattr(self, "output_nodes", None) or [self.output_node] def make_kernel_render( template_node: TemplateBuffer, @@ -1241,8 +1380,10 @@ def make_kernel_render( self.input_nodes, self.output_node.get_layout(), make_kernel_render, - bmreq, + False, # supports_epilogue_fusion self, + kwargs, + "" # Currently Empty description ) def get_tile_candidates(self, **kwargs): diff --git a/README.md b/README.md index 103131c1..03041355 100644 --- a/README.md +++ b/README.md @@ -99,16 +99,15 @@ The `tests` directory contains several AI workloads examples. ```bash python tests/test_matmul.py ``` -The result is stored to `TORCHSIM_DUMP_PATH/hash/togsim_result/`. The log file contains detailed core, memory, and interconnect stats. +The result is stored to `TORCHSIM_LOG_PATH/hash/togsim_result/`. The log file contains detailed core, memory, and interconnect stats. ### Run Your Own Model on PyTorchSim You can run your own PyTorch model on PyTorchSim by setting up a custom NPU device. This method also applies when you want to simulate models beyond the provided examples. ```python import torch -from Scheduler.scheduler import PyTorchSimRunner -# Declare a custom NPU device -device = PyTorchSimRunner.setup_device().custom_device() + +device = torch.device("npu:0") # Declare you own model (e.g. resnet18 from torchvision) from torchvision.models import resnet18 @@ -197,9 +196,9 @@ Log contains memory & core stats. [2025-12-05 08:05:52.538] [info] Total execution cycles: 2065 [2025-12-05 08:05:52.538] [info] Wall-clock time for simulation: 0.147463 seconds ``` -The log is dumped in `TORCHSIM_DUMP_PATH` and you can set the path as below. +The log is dumped in `TORCHSIM_LOG_PATH` and you can set the path as below. ```bash -export TORCHSIM_DUMP_PATH=/tmp/torchinductor # output file dump path +export TORCHSIM_LOG_PATH=/tmp/torchinductor # output file dump path ``` ## Training @@ -215,76 +214,95 @@ opt_step() `tests/test_mlp.py` provides an example of MLP training. ## Multi-tenancy -Our load generator supports multi-tenancy experiments. You can run a simple example by executing `tests/test_scheduler.py`. -```bash -python tests/test_scheduler.py -``` -Below is an example code of multi-tenancy `resnet18` and `EncoderBlock`. -In this example, the `Scheduler` is initialized with a number of request queues, a scheduling policy, and a TOGSimulator config file(`.json`). The compiled PyTorch models are then registered with a unique model id. -```python3 -import os -import sys +While the **`with TOGSimulator(config_path=...)`** block is active, **`TOGSIM_CONFIG`** is set to that YAML so **compilation and TOGSim use the same** hardware description. + +### 1. One TOGSim session, one continuous log + +If you want **one** log where kernels are simulated **in sequence** as a single run, wrap the code you already use to execute the compiled model with **`with TOGSimulator(config_path=...)`**. No other API is required; every forward inside the block shares that session. + +```python import torch -from torchvision.models import resnet18 -base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') -config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json' +from Simulator.simulator import TOGSimulator -sys.path.append(base_path) -from tests.test_transformer import EncoderBlock -from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request, poisson_request_generator -scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) +# ... build model, torch.compile, tensors on npu:0 as usual ... -# Register compiled model -target_model0 = resnet18().eval() -target_model1 = EncoderBlock(768, 12).eval() -opt_model0 = torch.compile(target_model0.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last)) -opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device())) -SchedulerDNNModel.register_model("model0", opt_model0) -SchedulerDNNModel.register_model("model1", opt_model1) +with TOGSimulator(config_path=config): + y = compiled_model(x) ``` -The config file(`.json`) specifies two key items: -- `num_partition`: The total number of independent request queues to create. -- `partition`: Defines the hardware mapping, assigning each queue (identified by its index) to a specific physical core. -For example, the configuration below creates two scheduling queues (`0` and `1`) and maps `core_0` to queue `0` and `core_1` to queue `1`: +### 2. Multi-tenancy and explicit scheduling (`launch_model`) + +For **multi-tenant** or **interleaved** execution, you usually need to attach a **timestamp** and a **`stream_index`** to each launch so the simulator can order work correctly. Use **`torch.npu.launch_model(compiled_model, *inputs, stream_index=..., timestamp=...)`** for that; plain `compiled_model(x)` does not carry those parameters. + +**`stream_index`** is the **request-queue / partition index** in the TOGSim config: it must match the **values** in the **`partition`** map (each queue index is mapped to a **core**). For example, `stream_index=0` goes to the queue bound to `core_0`, `stream_index=1` to the queue for `core_1`, and so on. + +**`timestamp`** is in **nanoseconds** (simulation time for ordering launches). Use `0` when you do not need explicit times beyond submission order. + +```python +with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_model1, x1, stream_index=0, timestamp=0) + torch.npu.launch_model(opt_model2, x2, stream_index=1, timestamp=0) + torch.npu.synchronize() + torch.npu.launch_model(opt_model1, x1, stream_index=0, timestamp=0) + torch.npu.launch_model(opt_model2, x2, stream_index=1, timestamp=0) +``` + +Here **`synchronize()`** acts as a barrier: it does not return until every **`launch_model`** issued **above** it has finished in the simulator. The later pair of `launch_model` calls therefore runs only after those earlier models have fully completed—so the sync is the point in the timeline where **all preceding launches are done**. + +```bash +python tests/test_scheduler.py +``` + +Use a TOGSim config(`.yml`) that defines **partitions** when mapping queues to cores, for example: + +- **`num_partition`**: Number of independent request queues (valid **`stream_index`** values are `0 … num_partition-1`). +- **`partition`**: Maps each **core** name to a **queue index**; that index is the same **`stream_index`** you pass to **`launch_model`**. + ``` "num_partition" : 2, "partition": { - "core_0":0, - "core_1":1 + "core_0": 0, + "core_1": 1 } ``` -Next, DNN model requests are generated and submitted. We provide a `poisson_request_generator` utility, which generates request arrival times. -Each `Request` is created with its model name, data, and a request_queue_idx to specify its target queue, then added via `scheduler.add_request`. -As shown in the code, `model0` requests are queued to `request_queue_idx=0`, while `model1` requests are queued to `request_queue_idx=1`. -```python3 -# Load Generation +Here `stream_index=0` selects queue `0` (core_0), `stream_index=1` selects queue `1` (core_1). + +### 3. Load generation (Poisson arrivals) + +The **`poisson_request_generator`** in **`Scheduler.scheduler`** yields synthetic **arrival times** (in **milliseconds**). Merge those with **`launch_model`**: convert each time to **nanoseconds** for **`timestamp`**, set **`stream_index`** to the target partition queue, and run all launches inside one **`with TOGSimulator(...)`** so a **single** log captures the full trace. + +```python +from Scheduler.scheduler import poisson_request_generator + model0_lambda = 5.0 model1_lambda = 3.0 -max_time = 1000.0 # [s] +max_time_msec = 1000.0 # Poisson horizon [ms] -# Generate Possion distribution requests for model0 -for model0_request_time in poisson_request_generator(model0_lambda, max_msec_time=max_time): - x = torch.randn(1, 3, 224, 224) - new_request = Request("model0", [x], [], request_queue_idx=0) - scheduler.add_request(new_request, request_time=model0_request_time) +events = [] +for t in poisson_request_generator(model0_lambda, max_msec_time=max_time_msec): + x = torch.randn(1, 3, 224, 224, device=device) + events.append((t, 0, opt_model0, (x,))) # stream_index 0 → queue / partition 0 -# Generate Possion distribution requests for model1 -for model1_request_time in poisson_request_generator(model1_lambda, max_msec_time=max_time): - x = torch.randn(128, 768) - new_request = Request("model1", [x], [], request_queue_idx=1) - scheduler.add_request(new_request, request_time=model1_request_time) -``` +for t in poisson_request_generator(model1_lambda, max_msec_time=max_time_msec): + x = torch.randn(128, 768, device=device) + events.append((t, 1, opt_model1, (x,))) # stream_index 1 → queue / partition 1 -Finally, `scheduler.schedule()` is called in a loop until all requests are processed. -```python3 -# Run scheduler -while not scheduler.is_finished(): - scheduler.schedule() +events.sort(key=lambda e: e[0]) + +with TOGSimulator(config_path=config): + for t_msec, stream_index, model, args in events: + torch.npu.launch_model( + model, + *args, + stream_index=stream_index, + timestamp=int(t_msec * 1e6), + ) # ms → ns ``` +The two Poisson streams are **combined and sorted by time** so launches follow a single global arrival order. + ## Compiler Optimizations PyTorchSim compiler supports several fusion optimizations: - GEMM prologue fusion @@ -396,7 +414,6 @@ export TORCHSIM_USE_TIMING_POOLING=0 # use lightweight pooling for timing "icnt_injection_ports_per_core" : 16 // Interconnect injection ports per core "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", // Booksim2 config file path - "precision" : 4, // Element's precision in tensor (Byte) "scheduler" : "simple", // Scheduler type (Now, only support simple scheduler) "num_partition" : 2, // Multi-core Partitioning "partition": { // allocate request queue index @@ -415,7 +432,7 @@ export TORCHSIM_USE_TIMING_POOLING=0 # use lightweight pooling for timing ``` You can set TOGSim config path as below. ```bash -export TORCHSIM_CONFIG=/workspace/PyTorchSim/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json +export TOGSIM_CONFIG=/workspace/PyTorchSim/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml ``` ## Future Works Currently, PyTorchSim supports PyTorch 2.2. Support for newer versions will be added soon. diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index ffe8e4fc..2b3aac92 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -1,29 +1,11 @@ -from typing import List -import os -import numpy as np -import torch -from pathlib import Path -import importlib.util -from PyTorchSimFrontend.extension_codecache import hash_prefix -from Simulator.simulator import TOGSimulator -from PyTorchSimFrontend import extension_config - -def import_module_from_path(module_name, path): - module_path = Path(path) # Convert to Path object for safety - if not module_path.exists() or not module_path.is_file(): - raise FileNotFoundError(f"No such file: '{module_path}'") - - spec = importlib.util.spec_from_file_location(module_name, module_path) - if spec is None: - raise ImportError(f"Could not load module from path: '{module_path}'") +"""Poisson load helpers for synthetic request arrival times.""" - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) +import numpy as np - return module def poisson_request_generator(lambda_requests, max_msec_time=None): - current_time = 0.0 # msec + """Yield synthetic arrival times in milliseconds (first sample is 0).""" + current_time = 0.0 # msec yield 0 while max_msec_time is None or current_time < max_msec_time: @@ -34,512 +16,3 @@ def poisson_request_generator(lambda_requests, max_msec_time=None): break yield current_time - -class Request: - """ Each request has model name, it's own id, and requested time. """ - request_id = 0 - QUEUED = 1 - RUNNING = 2 - INCREMENT = 3 - FINISHED = 4 - def __init__(self, model:str, batchable_input_tensor : List[torch.Tensor], - shared_input_tensor: List[torch.tensor], request_queue_idx=0) -> None: - self.model = model - self.batchable_input_tensor = batchable_input_tensor - self.shared_input_tensor = shared_input_tensor - self.arrival_time = None - self.start_time = [] - self.finish_time = [] - self.state = self.QUEUED - self.id = self.allocate_id() - self.request_queue_idx = request_queue_idx - - def allocate_id(self): - allocated_id = Request.request_id - Request.request_id += 1 - return allocated_id - - def set_start(self, start_time): - self.state = self.RUNNING - self.start_time.append(start_time) - - def set_finished(self, finish_time): - self.state = self.FINISHED - self.finish_time.append(finish_time) - - def get_latency(self): - # Todo. Provide Toke-By-Token - if self.state == self.FINISHED: - turnaround_time = self.finish_time[-1] - self.arrival_time - else: - turnaround_time = None - - if self.start_time: - response_time = self.start_time[0] - self.arrival_time - else: - response_time = None - - if self.start_time and self.finish_time: - tbt_time = [i-j for i,j in zip(self.finish_time, self.start_time)] - else: - tbt_time = [] - - return turnaround_time, response_time, tbt_time - - def free_memory(self): - """ Free memory resources that are allocated for handle this request """ - return - - def __str__(self) -> str: - return f"Request{self.id} Model: '{self.model}', Arrival: {self.arrival_time}, Start: {self.start_time}, End: {self.finish_time}, State: {self.state}, Partion: {self.request_queue_idx}" - -class RequestReturn: - INCREMENT = 0 - FINISHED = 1 - def __init__(self, state) -> None: - self.state = state - - def is_finished(self): - return self.state == self.FINISHED - - def is_increment(self): - return self.state == self.INCREMENT - -class SchedulerDNNModel: - MODEL_MAP = {} - def __init__(self, batched_req : List[Request], partition_idx) -> None: - self.model_name = batched_req[0].model - self.batched_req = batched_req - self.args = None - self.model = self.find_model(self.model_name) - self.partition_idx = partition_idx - - def find_model(self, model_name : str): - if model_name in SchedulerDNNModel.MODEL_MAP: - return SchedulerDNNModel.MODEL_MAP[model_name] - else: - raise KeyError(f'[Scheduler] Requested model "{model_name}" is not registered...') - - def get_batchable_input(self): - batched_input_tensor = [] - for i in range(len(self.batched_req[0].batchable_input_tensor)): - tensor_list = [req.batchable_input_tensor[i] for req in self.batched_req] - batched_input_tensor.append(torch.concat(tensor_list, dim=0)) - return batched_input_tensor - - def get_shared_input(self): - return self.batched_req[0].shared_input_tensor - - def get_input(self): - return self.get_batchable_input() + self.get_shared_input() - - def __str__(self): - return f"DNN Model: {self.model_name}, Partion idx: {self.partition_idx} Req: {self.batched_req[0]}" - - @staticmethod - def register_model(model_name : str, compiled_model): - SchedulerDNNModel.MODEL_MAP[model_name] = compiled_model - -class PyTorchSimRunner: - PARTITION_BUSY = 0 - PARTITION_IDLE = 1 - SELECT_NOTHING = 2 - def __init__(self, tog_simulator : TOGSimulator, num_partion=1) -> None: - self.module = self.setup_device() - self.num_partion = num_partion - self.launch_model_dicts = [] - self.nested_launch_model_dicts = [] - self.partition_state = [] - for i in range(self.num_partion): - self.launch_model_dicts.append({}) - self.nested_launch_model_dicts.append({}) - self.partition_state.append(self.PARTITION_IDLE) - - self.finish_req_dict = {} - self.tog_simulator = tog_simulator - - # Dry run for compile and create generator - os.environ["TOGSIM_EAGER_MODE"] = "1" - - @staticmethod - def setup_device(): - source_file_path = os.path.dirname(os.path.abspath(__file__)) - source_file = os.path.join( - source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimFrontend/extension_device.cpp" - ) - - import torch.utils.cpp_extension - module = torch.utils.cpp_extension.load( - name="npu", - sources=[ - str(source_file), - ], - extra_cflags=["-g"], - verbose=True, - ) - - torch.utils.rename_privateuse1_backend("npu") - from torch._inductor.codegen.common import ( - get_scheduling_for_device, - get_wrapper_codegen_for_device, - register_backend_for_device, - ) - from PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - ExtensionWrapperCodegen, - ) - from PyTorchSimFrontend.mlir.mlir_scheduling import ( - MLIRScheduling - ) - register_backend_for_device( - "npu", MLIRScheduling, ExtensionWrapperCodegen - ) - assert( - get_scheduling_for_device("npu") == MLIRScheduling - ) - assert( - get_wrapper_codegen_for_device("npu") - == ExtensionWrapperCodegen - ) - return module - - def submit(self, batched_req, partition_idx) -> List[RequestReturn]: - # FIXME. Construct SchedulerDNNModel - batched_req_model = self.get_compiled_model(batched_req, partition_idx) - self.prepare_model(batched_req_model) - - def get_compiled_model(self, batched_req: List[Request], request_queue_idx): - compiled_model = SchedulerDNNModel(batched_req, request_queue_idx) - return compiled_model - - def is_partition_idle(self, partition_idx): - return len(self.launch_model_dicts[partition_idx]) == 0 - - def is_any_idle(self, skip_list): - return any([self.is_partition_idle(i) and not skip_list[i] for i in range(self.num_partion)]) - - def is_all_idle(self): - return all([self.is_partition_idle(i) for i in range(self.num_partion)]) - - def prepare_model(self, req_model: SchedulerDNNModel): - result_path = os.path.join(extension_config.CONFIG_TORCHSIM_LOG_PATH, "togsim_result", req_model.model_name) - os.makedirs(result_path, exist_ok=True) - index = str(len(os.listdir(result_path))) - - # Prepare input tensor - input_tensor_list = req_model.get_input() - input_tensor_list = [input_tensor.to(device=self.module.custom_device()) for input_tensor in input_tensor_list] - - # This model-call will return generator - ret = req_model.model(*input_tensor_list) - self.launch_model_dicts[req_model.partition_idx][req_model] = ret - - def finish_model(self, model : SchedulerDNNModel, output : torch.Tensor): - for req in model.batched_req: - # TODO. finish time - self.finish_req_dict[req] = RequestReturn(RequestReturn.FINISHED) - - def prepare_launch_kernel(self, kernel, inputs): - result_path, runtime_path, _ = kernel(*inputs) - onnx_path = os.path.join(result_path, "tile_graph.onnx") - - attribute_path = os.path.join(runtime_path, "attribute") - attribute_path = self.tog_simulator.create_attribute_file(attribute_path, inputs) - return onnx_path, attribute_path - - def launch_kernel(self, current_cycle, partion_idx=0): - # Check partition is busy - if self.partition_state[partion_idx] != self.PARTITION_IDLE: - return self.partition_state[partion_idx] - result = self.select_kernel(partion_idx) - if result == self.SELECT_NOTHING: - return self.SELECT_NOTHING - kernel, inputs = result - if not isinstance(kernel, str): - onnx_path, attribute_path = self.prepare_launch_kernel(kernel, inputs) - else: - onnx_path, attribute_path = kernel, inputs - self.partition_state[partion_idx] = self.PARTITION_BUSY - return self.tog_simulator.launch(onnx_path, attribute_path, current_cycle, partion_idx) - -class FIFORunner(PyTorchSimRunner): - def __init__(self, tog_simulator: TOGSimulator, num_partion=1) -> None: - super().__init__(tog_simulator, num_partion) - - def select_kernel(self, partition_idx): - while len(self.nested_launch_model_dicts[partition_idx]) or len(self.launch_model_dicts[partition_idx]): - if len(self.nested_launch_model_dicts[partition_idx]): - target_dict = self.nested_launch_model_dicts - else: - target_dict = self.launch_model_dicts - - # Select FIFO manner - req, target_model = next(iter(target_dict[partition_idx].items())) - try: - kernel, inputs = next(target_model) - - # For extern call - if isinstance(kernel, str): - return kernel, inputs - - # For convolution... - if not hasattr(kernel, "future"): - nested_gen = kernel(*inputs) - self.nested_launch_model_dicts[partition_idx] = {req : nested_gen} - kernel, inputs = \ - next(self.nested_launch_model_dicts[partition_idx][req]) - return kernel, inputs - except StopIteration as e: - # Retry - if target_dict == self.launch_model_dicts: - self.finish_model(req, e.value) - del target_dict[partition_idx][req] - # No proper kernel now - return self.SELECT_NOTHING - -class RoundRobinRunner(PyTorchSimRunner): - def __init__(self, tog_simulator: TOGSimulator, num_partion=1) -> None: - super().__init__(tog_simulator, num_partion) - self.next_pointer = None - - def select_kernel(self, partition_idx): - while len(self.nested_launch_model_dicts[partition_idx]) or len(self.launch_model_dicts[partition_idx]): - if len(self.nested_launch_model_dicts[partition_idx]): - target_dict = self.nested_launch_model_dicts - else: - target_dict = self.launch_model_dicts - - req_list = list(target_dict[partition_idx].keys()) - # Select RR manner - if self.next_pointer is None or self.next_pointer not in req_list: - req = req_list[0] - pos = 0 - else: - req = self.next_pointer - pos = req_list.index(req) - - # Set Next pointer - if pos + 1 < len(req_list): - self.next_pointer = req_list[pos+1] - else: - self.next_pointer = req_list[0] - - target_model = self.launch_model_dicts[partition_idx][req] - try: - kernel, inputs = next(target_model) - - # For convolution... - if not hasattr(kernel, "future"): - nested_gen = kernel(*inputs) - self.nested_launch_model_dicts[partition_idx] = {req : nested_gen} - kernel, inputs = \ - next(self.nested_launch_model_dicts[partition_idx][req]) - return kernel, inputs - except StopIteration as e: - # Retry - if target_dict == self.launch_model_dicts: - self.finish_model(req, e.value) - del self.launch_model_dicts[partition_idx][req] - # No proper kernel now - return self.SELECT_NOTHING - -class Scheduler: - - FIFO_ENGINE = 0 - RR_ENGINE = 1 - def __init__(self, num_request_queue=1, max_batch=1, engine_select=FIFO_ENGINE, togsim_config=extension_config.CONFIG_TOGSIM_CONFIG) -> None: - self.current_cycle = 0 - self.max_batch = max_batch - self.num_request_queue = num_request_queue - self.request_queue : List[List[Request]] = [] - for i in range(self.num_request_queue): - self.request_queue.append([]) - self.finish_queue : List[Request] = [] - - togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") - self.tog_simulator = TOGSimulator(togsim_path, togsim_config) - self.tog_simulator.interactive_simulation() - if engine_select == Scheduler.FIFO_ENGINE: - self.execution_engine = FIFORunner(self.tog_simulator, self.num_request_queue) - elif engine_select == Scheduler.RR_ENGINE: - self.execution_engine = RoundRobinRunner(self.tog_simulator, self.num_request_queue) - else: - print(f"Not supporetd engine type {engine_select}") - exit(1) - - def add_request(self, request: Request, request_time=-1): - """register model at timestamp time - request_time : msec - """ - request_time = self.current_time() if request_time == -1 else request_time - request.arrival_time = request_time - self.request_queue[request.request_queue_idx].append(request) - - def request_empty(self, request_queue_idx): - return len(self.request_queue[request_queue_idx])==0 - - def select(self, request_queue_idx=0) -> List[Request]: - """ - Select 1 request from request_queue in FCFS manner. - If there is no proper request, return None - """ - candidate_req = [] - if not self.request_queue[request_queue_idx]: - return candidate_req - for req in self.request_queue[request_queue_idx]: - - if self.msec_to_cycle(req.arrival_time) <= self.current_cycle and req.state == Request.QUEUED: - candidate_req.append(req) - - # Stop batching - if self.max_batch <= len(candidate_req): - break - return candidate_req - - def next_request_time(self, request_queue_idx=0): - for req in self.request_queue[request_queue_idx]: - if req.state == Request.QUEUED: - return req, req.arrival_time - return None, -1 - - def nearest_next_reqeust_time(self): - nearest_req = None - nearest_arrival_time = -1 - for i in range(self.num_request_queue): - req, arrival_time = self.next_request_time(i) - if nearest_arrival_time == -1 and arrival_time != -1: - nearest_req = req - nearest_arrival_time = arrival_time - elif arrival_time != -1 and nearest_arrival_time > arrival_time: - nearest_req = req - nearest_arrival_time = arrival_time - return nearest_req, nearest_arrival_time - - def finish_request(self, req : Request): - req.set_finished(self.current_time()) - - # Free resources - req.free_memory() - - # Move to finish queue - self.finish_queue.append(req) - self.request_queue[req.request_queue_idx].remove(req) - turnaround_time, response_time, tbt_time = req.get_latency() - print(f"[Request-{req.id} finished] partition: {req.request_queue_idx} arrival_time: " - f"{req.arrival_time} start_time: {req.start_time[0]} turnaround latency: {turnaround_time}, " - f"response time: {response_time} tbt_time: {tbt_time}") - - def per_schedule(self, request_queue_idx): - # Wait partition is idle - if not self.execution_engine.is_partition_idle(request_queue_idx): - return False - - request_list = self.select(request_queue_idx) - if not request_list: - return False - - print(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}", flush=True) - for req in request_list: - req.set_start(self.current_time()) - print(f"[Request-{req.id} issue] partition: {req.request_queue_idx} " - f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}", flush=True) - # Submit batched request - self.execution_engine.submit(request_list, request_queue_idx) - - return True - - def check_finish_request(self): - # Check finished request - while self.execution_engine.finish_req_dict: - req, req_ret = next(iter(self.execution_engine.finish_req_dict.items())) - self.finish_request(req) - del self.execution_engine.finish_req_dict[req] - - def schedule(self): - # Try schedule all request queue - result = [] - for i in range(self.num_request_queue): - result.append(self.per_schedule(i)) - - # Try move to next nearest request time - next_req, next_time = self.nearest_next_reqeust_time() - if next_req is None and self.execution_engine.is_all_idle(): - # No request remained... - return - - # Need to forward the time until next_arrival_time - if self.execution_engine.is_all_idle(): - reason = self.tog_simulator.until(self.msec_to_cycle(next_time)) - self.current_cycle = self.tog_simulator.cycle() - else: - self.run(next_time) - return - - def run(self, until_time): - req_empty_info = [self.request_empty(i) for i in range(self.execution_engine.num_partion)] - def execute_cycle(): - launch_ret_info = [] - for i in range(self.execution_engine.num_partion): - if self.execution_engine.partition_state[i] == PyTorchSimRunner.PARTITION_IDLE: - ret = self.execution_engine.launch_kernel(self.current_cycle, i) - launch_ret_info.append(ret) - - self.check_finish_request() - # Check if the stop condition is met - if self.execution_engine.is_any_idle(req_empty_info) or self.execution_engine.is_all_idle(): # Ignore empty request queue - return [] - - # Schedule jobs and update the current time - result_list = self.tog_simulator.until(self.msec_to_cycle(until_time)) - self.current_cycle = self.tog_simulator.cycle() - - for core_idx in result_list: - # Kernel is finished. So set idle state - self.execution_engine.partition_state[core_idx] = PyTorchSimRunner.PARTITION_IDLE - - return result_list - - if self.current_cycle >= self.msec_to_cycle(until_time): - until_time = -1 - - if until_time == -1: - while not self.execution_engine.is_any_idle(req_empty_info): - result = execute_cycle() - req_empty_info = [self.request_empty(i) for i in range(self.execution_engine.num_partion)] - # if result is not -1, schedule new request - if len(result)==0: - break - - else: - while self.current_cycle <= self.msec_to_cycle(until_time) and not self.execution_engine.is_all_idle(): - result = execute_cycle() - # if result is not -1, schedule new request - if len(result)==0: - break - return - - def is_request_queue_empty(self): - result = True - for i in range(self.num_request_queue): - result = result and (not len(self.request_queue[i])) - return result - - def is_finished(self): - if self.is_request_queue_empty() and self.execution_engine.is_all_idle(): - self.tog_simulator.wait() - return True - return False - - def current_time(self): - return self.cycle_to_msec(self.current_cycle) - - def cycle_to_msec(self, cycle): - freq = self.tog_simulator.get_core_freq() - return cycle / (freq / 1000) - - def msec_to_cycle(self, msec): - # We treat -1 as special time - if (msec == -1): - return msec - - freq = self.tog_simulator.get_core_freq() - return int(msec * (freq / 1000)) diff --git a/Simulator/simulator.py b/Simulator/simulator.py index 322d9b12..5b00d5d4 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -4,11 +4,12 @@ import subprocess import re import sys -import json +import yaml import time import datetime import threading from pathlib import Path +import uuid import torch import numpy as np @@ -16,6 +17,47 @@ from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs from PyTorchSimFrontend import extension_config +# Configure logger for Simulator module +logger = extension_config.setup_logger() +from tqdm import tqdm + + +class ProgressBar: + def __init__(self, desc, silent_mode=False, update_interval=0.5): + self.desc = desc + self.silent_mode = silent_mode + self.update_interval = update_interval + self.pbar = None + self.finished = False + self.progress_thread = None + + def __enter__(self): + if not self.silent_mode: + self.pbar = tqdm( + desc=self.desc, + bar_format='{desc}: {elapsed}', + leave=False, # Don't leave the bar when done (it will disappear) + ncols=80, + disable=False, + total=100, # Use a total for smooth animation + ) + # Update progress bar in a separate thread + def update_progress(): + while not self.finished: + self.pbar.update(1) + time.sleep(self.update_interval) + + self.progress_thread = threading.Thread(target=update_progress, daemon=True) + self.progress_thread.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.finished = True + if not self.silent_mode and self.pbar is not None: + self.pbar.close() + return False + + TORCH_TO_NUMPY = { torch.float32: np.float32, torch.float64: np.float64, @@ -26,6 +68,7 @@ torch.uint8: np.uint8, torch.bool: np.uint8, torch.bfloat16: np.float16, + torch.float16: np.float16, } class FunctionalSimulator(): @@ -53,7 +96,7 @@ def write_arg(self, arg, path, name): tensor = arg.cpu().detach() buffer_size = tensor.untyped_storage().size() buffer = (ctypes.c_char * buffer_size).from_address(tensor.data_ptr()) - t_arr = np.frombuffer(buffer, dtype=tensor.numpy().dtype, count=buffer_size // tensor.element_size()) + t_arr = np.frombuffer(buffer, dtype=TORCH_TO_NUMPY[tensor.dtype], count=buffer_size // tensor.element_size()) t_arr.tofile(data_path) else: assert(0) @@ -101,18 +144,19 @@ def run_spike(self, args, arg_attributes, runtime_path, binary, vectorlane_size= base_path= f"--base-path={runtime_path}" os.makedirs(os.path.join(runtime_path, "indirect_access"), exist_ok=True) os.makedirs(os.path.join(runtime_path, "dma_access"), exist_ok=True) - run = f'spike --isa rv64gcv --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' - if not silent_mode and extension_config.CONFIG_DEBUG_MODE: - print("[Spike] cmd> ", run) - print("[Spike] Running Spike simulator") + run = f'spike --isa rv64gcv_zfh --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' + if not silent_mode: + logger.debug(f"[Spike] cmd> {run}") + logger.info("[Spike] Running Spike simulator") run_cmd = shlex.split(run) try: stdout_setting = subprocess.DEVNULL if silent_mode else None stderr_setting = subprocess.DEVNULL if silent_mode else None - subprocess.check_call(run_cmd, stdout=stdout_setting, stderr=stderr_setting) + with ProgressBar("[Spike] Running simulation", silent_mode=silent_mode): + subprocess.check_call(run_cmd, stdout=stdout_setting, stderr=stderr_setting) except subprocess.CalledProcessError as e: if not silent_mode: - print("[Spike] Command failed with exit code", e.returncode) + logger.error(f"[Spike] Command failed with exit code {e.returncode}") error_msg = "" if e.returncode == 200: error_msg = "INVALID_SPAD_ACCESS" @@ -151,39 +195,22 @@ class CycleSimulator(): def __init__(self) -> None: pass - def compile_and_simulate(self, target_binary, array_size, vectorlane_size, silent_mode=False): - def show_progress(): - i = 0 - while not finished: - i = (i + 1) % 3 - tail = "." * i + " " * (3-i) - sys.stdout.write("\r[Gem5] Gem5 is running." + tail) - time.sleep(1) - print("") - + def compile_and_simulate(self, target_binary, vectorlane_size, silent_mode=False): dir_path = os.path.join(os.path.dirname(target_binary), "m5out") gem5_script_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "gem5_script/script_systolic.py") gem5_cmd = [extension_config.CONFIG_GEM5_PATH, "-r", "--stdout-file=sto.log", "-d", dir_path, gem5_script_path, "-c", target_binary, "--vlane", str(vectorlane_size)] + + if not silent_mode: + logger.debug(f"[Gem5] cmd> {' '.join(gem5_cmd)}") + logger.info("[Gem5] Gem5 simulation started") + try: - # Create progress thread - is_dryrun = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) or silent_mode - if not is_dryrun: - if extension_config.CONFIG_DEBUG_MODE: - print("[Gem5] cmd> ", " ".join(gem5_cmd)) - finished = False - progress_thread = threading.Thread(target=show_progress) - progress_thread.start() - output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) - finished = True - progress_thread.join() - else: - output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) + #with ProgressBar("[Gem5] Running simulation", silent_mode=is_dryrun): + output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) except subprocess.CalledProcessError as e: - print(f"[Gem5] Gem5 simulation failed with error: \"{e.output.decode()}\"") - if not is_dryrun: - finished = True - progress_thread.join() - raise RuntimeError(f"Gem5 Simulation Failed: \"{e.output.decode()}\"") + output_error = e.output.decode() if isinstance(e.output, bytes) else str(e.output) + logger.debug(f"[Gem5] Gem5 simulation failed with error: \"{output_error}\"") + raise RuntimeError(f"Gem5 Simulation Failed: \"{output_error}\"") with open(f"{dir_path}/stats.txt", "r") as stat_file: raw_list = stat_file.readlines() @@ -196,139 +223,215 @@ class TOGSimulator(): TOGSIM_RESULT_PATH_KEY = "TOGSIM_RESULT_PATH" FINISH_STR = "Simulation finished" ALLOC_POOL = dict() # For eagermode buffer plan - def __init__(self, togsim_path, config_path, vectorlane_size=-1) -> None: + _TOGSIM_CONFIG_ENV_UNSET = object() + def __init__(self, config_path=None, togsim_path=None) -> None: + if config_path is None: + config_path = extension_config.CONFIG_TOGSIM_CONFIG + if togsim_path is None: + togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") + self.base_dir = togsim_path self.config_path = config_path - self.config_json = self.load_json(self.config_path) + self.config_yaml = self.load_yaml(self.config_path) self.process = None - self.vectorlane_size = vectorlane_size + self._next_kernel_id = 0 # Auto-incrementing kernel ID - def get_togsim_command(self): - bin = os.path.join(self.base_dir, "build/bin/Simulator") - config = os.path.join(self.base_dir, self.config_path) - cmd = f"{bin} --config {config}" - return cmd + # Create FIFOs for command and event communication + self.fifo_dir = os.path.join("/tmp", f"togsim_fifo_{os.getpid()}") + os.makedirs(self.fifo_dir, exist_ok=True) + self.trace_file_path = os.path.join(self.fifo_dir, "cmd_fifo") + self.trace_log = "# command_type, kernel_id, device_index, stream_index, tog_path, attribute_path, timestamp\n" - def simulation(self, model_path, attribute_path="", silent_mode=False): - def show_progress(): - i = 0 - while not finished: - i = (i + 1) % 3 - tail = "." * i + " " * (3-i) - sys.stdout.write("\r[TOGSim] TOGSim is running." + tail) - time.sleep(1) - print("") - cmd = f"{self.get_togsim_command()} --models_list {model_path}" - if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: - cmd += f" --log_level {extension_config.CONFIG_TOGSIM_DEBUG_LEVEL}" - if attribute_path: - cmd = f"{cmd} --attributes_list {attribute_path}" - if not silent_mode and extension_config.CONFIG_DEBUG_MODE: - print("[TOGSim] cmd> ", cmd) + # Create FIFOs if they don't exist + if os.path.exists(self.trace_file_path): + os.remove(self.trace_file_path) + os.mkfifo(self.trace_file_path) - # Create progress thread - if not silent_mode: - finished = False - progress_thread = threading.Thread(target=show_progress) - progress_thread.start() + # Start TOGSim process + self._start_process() + + # Open trace file FIFO once and keep it open (after process starts) + self._trace_file_lock = threading.Lock() try: - result = subprocess.check_output(shlex.split(cmd)) - if not silent_mode: - finished = True - progress_thread.join() - except subprocess.CalledProcessError as e: - if not silent_mode: - finished = True - progress_thread.join() - print("[TOGSim] Command failed with exit code", e.returncode) - print("[TOGSim] Error output:", e.output) - assert 0 - # Save result to result_path - result_path = extension_config.CONFIG_TORCHSIM_LOG_PATH - os.makedirs(result_path, exist_ok=True) - file_name = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')+".log" - result_path = os.path.join(result_path, file_name) - with open(result_path, "w") as f: - f.write(result.decode()) - if not silent_mode or extension_config.CONFIG_DEBUG_MODE: - model_path_log = f' of "{model_path}" ' if extension_config.CONFIG_DEBUG_MODE else " " - print(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') - return result_path + self._trace_file_handle = open(self.trace_file_path, 'w') + except IOError as e: + logger.error(f"[TOGSim] Failed to open trace file: {e}") + raise RuntimeError(f"Failed to open trace file: {e}") + + def __enter__(self): + """Context manager entry. + + Sets ``TOGSIM_CONFIG`` to this instance's config path so that compilation + (``extension_config`` / codegen) uses the same YAML as TOGSim. Previous + value is restored in ``__exit__``. + """ + if "TOGSIM_CONFIG" in os.environ: + self._old_togsim_config_env = os.environ["TOGSIM_CONFIG"] + else: + self._old_togsim_config_env = self._TOGSIM_CONFIG_ENV_UNSET + os.environ["TOGSIM_CONFIG"] = os.path.abspath(self.config_path) + + self.old_tog_simulator = torch.npu.get_tog_simulator() + torch.npu.set_tog_simulator(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - automatically cleanup.""" + self.until() + torch.npu.set_tog_simulator(self.old_tog_simulator) + + if self._old_togsim_config_env is self._TOGSIM_CONFIG_ENV_UNSET: + os.environ.pop("TOGSIM_CONFIG", None) + else: + os.environ["TOGSIM_CONFIG"] = self._old_togsim_config_env - def interactive_simulation(self): - cmd = f"{self.get_togsim_command()} --mode interactive" + def _start_process(self): + cmd = f"{self.get_togsim_command(self.config_path, self.base_dir)} --models_list {self.trace_file_path}" if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: cmd += f" --log_level {extension_config.CONFIG_TOGSIM_DEBUG_LEVEL}" - if extension_config.CONFIG_DEBUG_MODE: - print("[TOGSim] cmd> ", cmd) + logger.debug(f"[TOGSim] cmd> {cmd}") if self.process is None: self.process = subprocess.Popen( shlex.split(cmd), - stdin=subprocess.PIPE, - stderr=subprocess.PIPE, + #stdout=subprocess.PIPE, + #stderr=subprocess.PIPE, universal_newlines=True ) else: - print("[TOGSim] Simulator is already running.") + logger.warning("[TOGSim] Simulator is already running.") - def stop(self): - if self.process: - self.process.terminate() - self.process.wait() - self.process = None - print("[TOGSim] Simulator stopped.") + def _cleanup_fifos(self): + """Clean up FIFO files""" + try: + if os.path.exists(self.trace_file_path): + os.remove(self.trace_file_path) + if os.path.exists(self.fifo_dir): + os.rmdir(self.fifo_dir) + except OSError as e: + logger.warning(f"[TOGSim] Failed to clean up FIFOs: {e}") + + def _send_command(self, command_type, device_index, stream_index, tog_path="", attribute_path="", timestamp=0): + """ + Internal method to send a command to TOGSim via FIFO. + + Args: + command_type: Type of command ("LAUNCH_KERNEL" or "DEVICE_SYNC") + device_index: Device index + stream_index: Stream index + tog_path: Path to TOG file (ONNX model) - empty for DEVICE_SYNC + attribute_path: Path to attribute file - empty for DEVICE_SYNC + timestamp: Timestamp in nanoseconds (default: 0) + + Returns: + int: The kernel ID assigned to this command + """ + if self.process is None: + raise RuntimeError("[TOGSim] Simulator process is not running") + + if self.process.poll() is not None: + raise RuntimeError("[TOGSim] Simulator process has terminated") + + # Get and increment kernel ID + kernel_id = self._next_kernel_id + self._next_kernel_id += 1 + + # Format command: command_type,kernel_id,device_index,stream_index,tog_path,attribute_path,timestamp + command = f"{command_type},{kernel_id},{device_index},{stream_index},{tog_path},{attribute_path},{timestamp}" + + with self._trace_file_lock: + # Write command to TOGSim + try: + self._trace_file_handle.write(command + '\n') + self._trace_file_handle.flush() + self.trace_log += command + '\n' + logger.debug(f"[TOGSim] Sent command: {command}") + except IOError as e: + logger.error(f"[TOGSim] Failed to write to trace file: {e}") + raise RuntimeError(f"Failed to send command to TOGSim: {e}") + return kernel_id + + def until(self): + # Make sure that all kernels in the stream are finished + torch.npu.synchronize() + + # Close trace file handle if open + if self._trace_file_handle is not None: + try: + self._trace_file_handle.close() + except: + pass + self._trace_file_handle = None - def wait(self): if self.process: - print("[TOGSim] Waiting for simulation to complete...") - self.quit() self.process.wait() + + # Read output streams + stdout_output = "" + stderr_output = "" + if self.process.stdout: + stdout_output = self.process.stdout.read() + if self.process.stderr: + stderr_output = self.process.stderr.read() + + # Print stderr immediately if there's any error output + if stderr_output: + sys.stderr.write(stderr_output) + sys.stderr.flush() + + # Save stdout to result file + if stdout_output: + result_path = extension_config.CONFIG_TORCHSIM_LOG_PATH + os.makedirs(result_path, exist_ok=True) + file_name = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + ".log" + result_path = os.path.join(result_path, file_name) + with open(result_path, "w") as f: + f.write(stdout_output) + logger.info(f'[TOGSim] Simulation log is stored to "{result_path}"') self.process = None - print("[TOGSim] Simulation completed.") - def send_command(self, command): - if self.process: - try: - if extension_config.CONFIG_TORCHSIM_DEBUG_MODE: - print(command, flush=True) - self.process.stdin.write(command + '\n') - self.process.stdin.flush() - ret = self.process.stderr.readline().strip() - return ret - except BrokenPipeError: - err = self.process.stderr.readlines() - for line in err: - print(line) - self.process = None - exit(1) - else: - print("Simulator is not running.") - return None - - def launch(self, onnx_path, attribute_path, arrival_time=0, partion_id=0): - command = f"launch {self.config_path} {onnx_path} {attribute_path} {arrival_time} {partion_id}" - ret = self.send_command(command) - return 0 - - def cycle(self): - ret = self.send_command("cycle") - return int(ret.split(" ")[-1]) - - def until(self, until_cycle): - command = f"until {until_cycle}" - ret = self.send_command(command) - bitmap = int(ret.split(" ")[-1]) - indices = [] - for i in range(64): - if (bitmap >> i) & 1: - indices.append(i) - return indices - - def quit(self): - command = "quit" - ret = self.send_command(command) - return + # Save trace_log with same name but .trace extension + if self.trace_log: + result_path = extension_config.CONFIG_TORCHSIM_LOG_PATH + os.makedirs(result_path, exist_ok=True) + file_name = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + ".trace" + trace_path = os.path.join(result_path, file_name) + with open(trace_path, "w") as f: + f.write(self.trace_log) + logger.info(f'[TOGSim] Trace log is stored to "{trace_path}"') + + # Clean up FIFOs + self._cleanup_fifos() + + def launch_kernel(self, device_index, stream_index, tog_path, attribute_path, timestamp=0): + """ + Launch a kernel via FIFO communication. + + Args: + device_index: Device index + stream_index: Stream index + tog_path: Path to TOG file (ONNX model) + attribute_path: Path to attribute file + timestamp: Timestamp in nanoseconds (default: 0) + + Returns: + int: The kernel ID assigned to this launch + """ + return self._send_command("LAUNCH_KERNEL", device_index, stream_index, tog_path, attribute_path, timestamp) + + def device_synchronize(self, device_index): + """ + Synchronize all streams on a device via FIFO communication. + + Args: + device_index: Device index to synchronize + timestamp: Timestamp in nanoseconds (default: 0) + + Returns: + int: The command ID assigned to this synchronization + """ + # For device_synchronize, stream_index is not meaningful, use 0 + return self._send_command("DEVICE_SYNC", device_index, 0, "", "", 0) @classmethod def sram_alloc(cls, buf_name, addr_range): @@ -339,70 +442,148 @@ def sram_dealloc(cls, buf_name, addr_range): if buf_name in cls.ALLOC_POOL: del cls.ALLOC_POOL[buf_name] - def create_attribute_file(self, attribute_path, inputs, **kwargs): + @staticmethod + def write_kernel_attribute_file(attribute_dir, inputs, alloc_pool=None): + """ + Write kernel attribute YAML (address_info + sram_alloc) under attribute_dir. + + Does not require a TOGSimulator instance. alloc_pool defaults to class ALLOC_POOL. + + Args: + attribute_dir: Directory to hold numbered attribute files (created if needed) + inputs: Kernel input tensors (data_ptr used for address_info) + alloc_pool: Optional dict like ALLOC_POOL; defaults to TOGSimulator.ALLOC_POOL + + Returns: + Path to the written YAML file. + """ + if alloc_pool is None: + alloc_pool = TOGSimulator.ALLOC_POOL address_info = {} sram_buffer = {} - json_content = {} - os.makedirs(attribute_path, exist_ok=True) - index = str(len(os.listdir(attribute_path))) - attribute_path = os.path.join(attribute_path, index) + yaml_content = {} + + os.makedirs(attribute_dir, exist_ok=True) + index = str(len(os.listdir(attribute_dir))) + attribute_file = os.path.join(attribute_dir, index) for idx, tensor in enumerate(inputs): address_info[f"arg{idx}"] = tensor.data_ptr() - json_content["address_info"] = address_info + yaml_content["address_info"] = address_info - for buf_name, range in self.ALLOC_POOL.items(): + for buf_name, range in alloc_pool.items(): sram_buffer[buf_name] = range - json_content["sram_alloc"] = sram_buffer + yaml_content["sram_alloc"] = sram_buffer - with open(attribute_path, "w") as f: - json.dump(json_content, f, indent=4) + with open(attribute_file, "w") as f: + yaml.dump(yaml_content, f, default_flow_style=False) f.flush() - os.fsync(f.fileno()) # There could be a race condition. - return attribute_path + os.fsync(f.fileno()) + return attribute_file - def load_json(self, config_path): + def load_yaml(self, config_path): config_path = Path(config_path) if not config_path.is_file(): - raise FileNotFoundError(f"JSON file not found: {config_path}") + raise FileNotFoundError(f"YAML file not found: {config_path}") try: with open(config_path, "r") as file: - data = json.load(file) + data = yaml.safe_load(file) return data - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON format: {e}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML format: {e}") def get_core_freq(self): - if "core_freq_mhz" in self.config_json: - return self.config_json["core_freq_mhz"] * 1000 * 1000 # MHz + if "core_freq_mhz" in self.config_yaml: + return self.config_yaml["core_freq_mhz"] * 1000 * 1000 # MHz else: raise KeyError("Key 'core_freq' not found in JSON.") - def find_zero_sub_tensors(self, tensor): - x, y = self.vectorlane_size, self.vectorlane_size - zero_positions = {} + @staticmethod + def get_togsim_command(config_path, togsim_path=None): + if togsim_path is None: + togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") + bin = os.path.join(togsim_path, "build/bin/Simulator") + config = os.path.join(togsim_path, config_path) + cmd = f"{bin} --config {config}" + return cmd + + @staticmethod + def run_standalone(model_path, attribute_path="", autotune_mode=False, config_path=None, togsim_path=None): + """ + Run a single kernel simulation in standalone mode. + This method starts a new TOGSim process, runs the kernel, and waits for completion. + For streaming multiple kernels, use launch_kernel() instead. + + Args: + model_path: Path to TOG file (ONNX model) + attribute_path: Path to attribute file + autotune_mode: If True, run in autotune mode (silent) + config_path: Path to TOGSim config file (required) + togsim_path: Path to TOGSim directory (optional, defaults to CONFIG_TORCHSIM_DIR/TOGSim) + + Returns: + Path to the simulation result log file + """ + if config_path is None: + config_path = extension_config.CONFIG_TOGSIM_CONFIG + if togsim_path is None: + togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") + + # Create result path with appropriate filename + if autotune_mode: + base_dir = Path(model_path).parent / "togsim_result" + else: + base_dir = Path(extension_config.CONFIG_TORCHSIM_LOG_PATH) + + base_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + file_name = f"{timestamp}_{uuid.uuid4().hex[:8]}" + result_path = base_dir / f"{file_name}.log" + trace_file_path = base_dir / f"{file_name}.trace" + + # Create trace file in result directory + kernel_id, device_index, stream_index, timestamp = 0, 0, 0, 0 + command = f"LAUNCH_KERNEL,{kernel_id},{device_index},{stream_index},{model_path},{attribute_path},{timestamp}\n" + with open(trace_file_path, 'w') as trace_file: + trace_file.write(command) + trace_file.flush() + os.fsync(trace_file.fileno()) + + try: + cmd = f"{TOGSimulator.get_togsim_command(config_path, togsim_path)} --models_list {trace_file_path}" + if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: + cmd += f" --log_level {extension_config.CONFIG_TOGSIM_DEBUG_LEVEL}" + + if not autotune_mode: + logger.debug(f"[TOGSim] cmd> {cmd}") + logger.info("[TOGSim] TOGSim simulation started") + with ProgressBar("[TOGSim] Running simulation", silent_mode=autotune_mode): + result = subprocess.check_output(shlex.split(cmd)) + except subprocess.CalledProcessError as e: + logger.error(f"[TOGSim] Command failed with exit code {e.returncode}") + logger.error(f"[TOGSim] Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") + assert 0 - # Need to set vectorlane size - if self.vectorlane_size == -1: - return zero_positions + # Prevent race condition + with open(result_path, "w") as f: + f.write(result.decode()) + f.flush() + os.fsync(f.fileno()) - for i in range(0, tensor.shape[0], y): - for j in range(0, tensor.shape[1], x): - sub_tensor = tensor[i:i + y, j:j + x] - if np.all(sub_tensor == 0): - if i not in zero_positions: - zero_positions[i] = {} - zero_positions[i][j] = 0 # i pos : j pos : 0 - return zero_positions + if not autotune_mode: + import logging as _logging + model_path_log = f' of "{model_path}" ' if logger.isEnabledFor(_logging.DEBUG) else " " + logger.info(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') + return result_path @staticmethod def get_result_from_file(result_path): core_metrics = {} dram_channel_bw = {} - avg_dram_bw = None - simulation_time = None - total_cycle = None + avg_dram_bw = 0.0 + simulation_time = float("inf") + total_cycle = float("inf") # Read and find total stat position with open(result_path, "r") as f: @@ -417,7 +598,7 @@ def get_result_from_file(result_path): break if simulation_finished_idx == -1: - print("[TOGSim] Tried to parsing wrong formated output file!") + logger.warning(f"[TOGSim] Warning: Unable to parse the output file ({result_path}). The file may be improperly formatted.") return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time total_stat_lines = lines[simulation_finished_idx:] @@ -457,6 +638,24 @@ def get_result_from_file(result_path): return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time, total_cycle if __name__ == "__main__": - sim = TOGSimulator("/workspace/PyTorchSim/TOGSim", "/workspace/PyTorchSim/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json") - sim.interactive_simulation() - sim.until(4000) \ No newline at end of file + # Example paths (adjust these to your actual test files) + test_tog_path = "/workspace/PyTorchSim/outputs/6vxl6mwzhfl/tile_graph.onnx" + test_attribute_path = "/workspace/PyTorchSim/outputs/6vxl6mwzhfl/runtime_0001/attribute/0" + + # Test: Launch multiple kernels + sim = TOGSimulator(config_path="/workspace/PyTorchSim/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml") + with sim: + try: + id1 = torch.npu.launch_kernel(tog_path=test_tog_path, attribute_path=test_attribute_path) + id2 = torch.npu.launch_kernel(tog_path=test_tog_path, attribute_path=test_attribute_path) + id3 = torch.npu.launch_kernel(tog_path=test_tog_path, attribute_path=test_attribute_path) + except Exception as e: + print(f"Error during kernel launch: {e}") + + try: + id2 = torch.npu.launch_kernel(tog_path=test_tog_path, attribute_path=test_attribute_path) + id1 = torch.npu.launch_kernel(tog_path=test_tog_path, attribute_path=test_attribute_path) + id3 = torch.npu.launch_kernel(tog_path=test_tog_path, attribute_path=test_attribute_path) + except Exception as e: + print(f"Error during kernel launch: {e}") + print(sim.trace_log) \ No newline at end of file diff --git a/TOGSim/conanfile.txt b/TOGSim/conanfile.txt index 7a57f52f..ce5268c7 100644 --- a/TOGSim/conanfile.txt +++ b/TOGSim/conanfile.txt @@ -2,6 +2,6 @@ boost/1.79.0 robin-hood-hashing/3.11.5 spdlog/1.11.0 -nlohmann_json/3.11.2 +yaml-cpp/0.8.0 [generators] cmake diff --git a/TOGSim/extern/ramulator2 b/TOGSim/extern/ramulator2 index 748cd709..0a893236 160000 --- a/TOGSim/extern/ramulator2 +++ b/TOGSim/extern/ramulator2 @@ -1 +1 @@ -Subproject commit 748cd7099778d7196326aeb6384da92efb0c34c9 +Subproject commit 0a89323664f5767f544aa280456efc8a807504d0 diff --git a/TOGSim/include/Common.h b/TOGSim/include/Common.h index 640cba0c..2fd62681 100644 --- a/TOGSim/include/Common.h +++ b/TOGSim/include/Common.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -14,7 +15,6 @@ #include "SimulationConfig.h" #include "Instruction.h" -#include "nlohmann/json.hpp" #define MIN(x, y) (((x) > (y)) ? (y) : (x)) #define MIN3(x, y, z) MIN(MIN(x, y), z) @@ -24,10 +24,8 @@ #define PAGE_SIZE 4096 -using json = nlohmann::json; - typedef uint64_t addr_type; typedef uint64_t cycle_type; -uint32_t generate_id(); -SimulationConfig initialize_config(json config); \ No newline at end of file +bool loadConfig(const std::string& config_path, YAML::Node& config_yaml); +SimulationConfig initialize_config(YAML::Node config); \ No newline at end of file diff --git a/TOGSim/include/Core.h b/TOGSim/include/Core.h index e4d2f30a..286feb5f 100644 --- a/TOGSim/include/Core.h +++ b/TOGSim/include/Core.h @@ -10,6 +10,14 @@ #include "Tile.h" #include "SimulationConfig.h" #include "DMA.h" +#include "TraceLogTags.h" + +/** Log tag kind for Core::finish_instruction (see TraceLogTag names in TraceLogTags.h). */ +enum class InstFinishTraceTag { + Fnshed, + DmaIssueComplete, + DmaRespComplete, +}; class Core { public: @@ -22,7 +30,8 @@ class Core { virtual void cycle(); virtual void print_stats(); virtual void print_current_stats(); - virtual void finish_instruction(std::shared_ptr& inst); + virtual void finish_instruction(std::shared_ptr& inst, + InstFinishTraceTag tag = InstFinishTraceTag::Fnshed); virtual bool has_memory_request(); virtual void pop_memory_request(); virtual mem_fetch* top_memory_request() { return _request_queue.front(); } diff --git a/TOGSim/include/CoreTraceLog.h b/TOGSim/include/CoreTraceLog.h new file mode 100644 index 00000000..e78c1ef2 --- /dev/null +++ b/TOGSim/include/CoreTraceLog.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include "Instruction.h" +#include "TraceLogTags.h" + +/** + * Instruction / tile trace formatting and Core spdlog::trace helpers. + * Keeps Core.cc focused on simulation logic. + */ +namespace core_trace_log { + +std::string format_dma_inst_issued_detail(Instruction& inst); +/** Opcode + (detail...) for DMA issue / skip traces. */ +std::string format_dma_inst_issued_trace_line(Instruction& inst); +/** Opcode + (detail...) for COMP / BAR / MOVIN / MOVOUT finished or issued lines. */ +std::string format_instruction_detail_line(Instruction& inst); + +void trace_tile_scheduled(cycle_type core_cycle, uint32_t core_id, const std::string& tag15); + +void trace_instruction_line(cycle_type core_cycle, + uint32_t core_id, + const std::string& tag15, + uint64_t global_inst_id, + const std::string& message); + +void log_error_dma_instruction_invalid(cycle_type core_cycle, uint32_t core_id); +void log_error_dram_responses_trace_not_finished(cycle_type core_cycle, uint32_t core_id); +void log_error_instruction_already_finished(cycle_type core_cycle, + uint32_t core_id, + const std::string& opcode_name); +void log_error_undefined_opcode(); + +} // namespace core_trace_log diff --git a/TOGSim/include/DMA.h b/TOGSim/include/DMA.h index 2f41c6f3..08bdcab4 100644 --- a/TOGSim/include/DMA.h +++ b/TOGSim/include/DMA.h @@ -12,41 +12,41 @@ #include "Memfetch.h" struct VectorCompare { - bool operator()(const std::vector& a, const std::vector& b) const { + bool operator()(const std::vector& a, const std::vector& b) const { return a < b; } }; class DMA { public: - DMA(uint32_t id, uint32_t dram_req_size); + DMA(uint32_t id, uint32_t dram_req_size, bool l2_datacache_enabled); void issue_tile(std::shared_ptr inst); bool is_finished() { return _finished; } bool empty() { return _current_inst==nullptr; } - void register_tag(int subgraph_id, std::vector& key) { + void register_tag(int subgraph_id, std::vector& key) { if (tag_table.find(subgraph_id) == tag_table.end()) { - tag_table[subgraph_id] = std::map, uint32_t>(); - waiters[subgraph_id] = std::map, std::vector>>(); + tag_table[subgraph_id] = std::map, uint32_t>(); + waiters[subgraph_id] = std::map, std::vector>>(); } tag_table[subgraph_id][key] = 0; waiters[subgraph_id][key] = std::vector>(); } - void set_tag_finish(int subgraph_id, std::vector& key) { + void set_tag_finish(int subgraph_id, std::vector& key) { if (tag_table.find(subgraph_id) == tag_table.end()) { throw std::runtime_error("Subgraph does not exist in tag_table"); } tag_table[subgraph_id][key] = 1; } - void set_tag_sparse(int subgraph_id, std::vector& key) { + void set_tag_sparse(int subgraph_id, std::vector& key) { if (tag_table.find(subgraph_id) == tag_table.end()) { throw std::runtime_error("Subgraph does not exist in tag_table"); } tag_table[subgraph_id][key] = -1; } - void mark_tag_used(int subgraph_id, std::vector& key) { + void mark_tag_used(int subgraph_id, std::vector& key) { if (tag_table.find(subgraph_id) == tag_table.end()) { throw std::runtime_error("Subgraph does not exist in tag_table"); } else if (!tag_table[subgraph_id][key]) { @@ -59,17 +59,17 @@ class DMA { for (const auto& entry: tag_table) { auto subgraph_id = entry.first; for (const auto& tag_entry: tag_table[subgraph_id]) { - const std::vector& tag_key = tag_entry.first; + const std::vector& tag_key = tag_entry.first; uint32_t value = tag_entry.second; if (value == 1) { - spdlog::warn("[Tag Table][{}] Unused tag found: (key={}, val={})", + spdlog::debug("[Tag Table][{}] Unused tag found: (key={}, val={})", subgraph_id, fmt::format("[{}]", fmt::join(tag_key, ", ")), value); } } } } - bool tag_key_exist(int subgraph_id, std::vector& key) { + bool tag_key_exist(int subgraph_id, std::vector& key) { auto subgraph_it = tag_table.find(subgraph_id); if (subgraph_it == tag_table.end()) return false; @@ -78,7 +78,7 @@ class DMA { auto key_it = key_map.find(key); return key_it != key_map.end(); } - uint32_t get_tag_finish(int subgraph_id, std::vector& key) { + uint32_t get_tag_finish(int subgraph_id, std::vector& key) { auto subgraph_it = tag_table.find(subgraph_id); auto& key_map = subgraph_it->second; auto key_it = key_map.find(key); @@ -95,7 +95,7 @@ class DMA { tag_table.erase(subgraph_id); waiters.erase(subgraph_id); } - void register_tag_waiter(int subgraph_id, std::vector& key, std::shared_ptr inst) { + void register_tag_waiter(int subgraph_id, std::vector& key, std::shared_ptr inst) { auto subgraph_it = tag_table.find(subgraph_id); auto& key_map = subgraph_it->second; auto key_it = key_map.find(key); @@ -104,7 +104,7 @@ class DMA { } waiters[subgraph_id][key].push_back(inst); } - std::vector>& get_tag_waiter(int subgraph_id, std::vector& key) { + std::vector>& get_tag_waiter(int subgraph_id, std::vector& key) { auto subgraph_it = tag_table.find(subgraph_id); auto& key_map = subgraph_it->second; auto key_it = key_map.find(key); @@ -129,9 +129,10 @@ class DMA { size_t _tile_idx_stride=1; uint32_t _tile_idx; bool _finished=true; - std::map, uint32_t>> tag_table; - std::map, std::vector>>> waiters; + bool _l2_datacache_enabled = false; + std::map, uint32_t>> tag_table; + std::map, std::vector>>> waiters; std::queue _pending_accesses; bool _generated_once = false; }; -#endif \ No newline at end of file +#endif diff --git a/TOGSim/include/Dram.h b/TOGSim/include/Dram.h index d28ac25f..978bcdf9 100644 --- a/TOGSim/include/Dram.h +++ b/TOGSim/include/Dram.h @@ -29,6 +29,8 @@ class Dram { virtual void print_stat() {} virtual void print_cache_stats() {}; uint32_t get_channels_per_partition() { return _n_ch_per_partition; } + new_addr_type partition_dram_address(new_addr_type raw_addr) const; + protected: SimulationConfig _config; CacheConfig _m_cache_config; @@ -37,6 +39,7 @@ class Dram { uint32_t _n_partitions; uint32_t _n_ch_per_partition; uint32_t _req_size; + int _tx_log2 = 0; cycle_type _cycles; cycle_type* _core_cycles; std::vector> m_cache_latency_queue; @@ -83,8 +86,6 @@ class SimpleDRAM: public Dram { void print_cache_stats() override; private: int _latency = 1; - int _tx_ch_log2; - int _tx_log2; std::vector>> _mem; }; diff --git a/TOGSim/include/Instruction.h b/TOGSim/include/Instruction.h index 9fad13f4..bb62a440 100644 --- a/TOGSim/include/Instruction.h +++ b/TOGSim/include/Instruction.h @@ -18,13 +18,14 @@ typedef uint64_t addr_type; typedef uint64_t cycle_type; std::string opcode_to_string(Opcode opcode); +std::string format_tag_key_list_hex(const std::vector& tag_keys); class Instruction : public std::enable_shared_from_this { public: Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, addr_type dram_addr, - std::vector tile_size, std::vector tile_stride, size_t precision, - std::vector tag_idx_list, std::vector tag_stride_list, - std::vector accum_tag_idx_list); + std::vector tile_size, std::vector tile_stride, size_t elem_bits, + std::vector tag_idx_list, std::vector tag_stride_list, + std::vector accum_tag_idx_list); Instruction(Opcode opcode); void finish_instruction(); void add_child(std::shared_ptr child); @@ -32,6 +33,7 @@ class Instruction : public std::enable_shared_from_this { const Opcode get_opcode() { return opcode; } bool is_dma_read() { return opcode == Opcode::MOVIN; } bool is_dma_write() { return opcode == Opcode::MOVOUT; } + bool is_dma_instruction() const { return opcode == Opcode::MOVIN || opcode == Opcode::MOVOUT; } bool is_async_dma() { return _is_async_dma; } bool is_indirect_mode() { return _is_indirect_mode; } std::string get_indirect_index_path() { return _indirect_index_path; } @@ -45,11 +47,12 @@ class Instruction : public std::enable_shared_from_this { } } size_t get_tile_numel() { return _tile_numel; } - size_t get_precision() { return _precision; } + size_t get_elem_bits() const { return _elem_bits; } void inc_waiting_request(); void dec_waiting_request(); size_t get_waiting_request() { return _nr_waiting_request; } std::vector& get_tile_size() { return tile_size; } + std::vector& get_tile_stride() { return tile_stride; } void set_overlapping_cycle(cycle_type cycle) { overlapping_cycle = cycle; } cycle_type get_overlapping_cycle() { return overlapping_cycle; } cycle_type get_compute_cycle() { return compute_cycle; } @@ -68,12 +71,12 @@ class Instruction : public std::enable_shared_from_this { int get_compute_type() { return _compute_type; } void set_numa_id(int numa_id) { _numa_id = numa_id; } uint32_t get_numa_id() { return _numa_id; } - std::vector& get_tag_idx_list() { return _tag_idx_list; } - std::vector& get_tag_stride_list() { return _tag_stride_list; } - std::vector& get_tag_id() { return _tag_key; } - void set_addr_name(std::string name, int id) { _addr_name = name; _addr_id = id; } + std::vector& get_tag_idx_list() { return _tag_idx_list; } + std::vector& get_tag_stride_list() { return _tag_stride_list; } + std::vector& get_tag_id() { return _tag_key; } + void set_addr_name(std::string name, int64_t id) { _addr_name = name; _addr_id = id; } std::string get_addr_name() { return _addr_name; } - int get_addr_id() { return _addr_id; } + int64_t get_addr_id() { return _addr_id; } void set_nr_inner_loop(int nr) { _nr_inner_loop = nr; } int get_nr_inner_loop() { return _nr_inner_loop; } void set_is_async(bool is_async) { _is_async_dma = is_async; } @@ -81,6 +84,7 @@ class Instruction : public std::enable_shared_from_this { bool is_sparse_inst() { return _is_sparse_inst; } void set_sparse_state(bool state) { _is_sparse_inst = state; } std::set>& get_child_inst() { return child_inst; } + uint64_t get_global_inst_id() const { return _global_inst_id; } cycle_type start_cycle; cycle_type finish_cycle; @@ -89,6 +93,9 @@ class Instruction : public std::enable_shared_from_this { bool finished=false; int subgraph_id; private: + uint64_t _global_inst_id = 0; + static uint64_t _next_global_inst_id; + void *_owner = nullptr; std::list>* _owner_ready_queue_ref = nullptr; Opcode opcode; @@ -100,17 +107,17 @@ class Instruction : public std::enable_shared_from_this { std::vector tile_stride; size_t _tile_numel; size_t _nr_waiting_request=0; - size_t _precision=0; + size_t _elem_bits = 0; addr_type dram_addr; uint32_t _numa_id = 0; // For DMA instruction int _compute_type = 0; - std::vector _tag_idx_list; - std::vector _tag_stride_list; - std::vector _tag_key; - std::vector _accum_tag_idx_list; + std::vector _tag_idx_list; + std::vector _tag_stride_list; + std::vector _tag_key; + std::vector _accum_tag_idx_list; std::vector _trace_address; std::string _addr_name; - int _addr_id; + int64_t _addr_id = 0; int _nr_inner_loop = 0; bool _is_async_dma=false; bool _is_indirect_mode=false; diff --git a/TOGSim/include/SimulationConfig.h b/TOGSim/include/SimulationConfig.h index 64cfa223..090f5520 100644 --- a/TOGSim/include/SimulationConfig.h +++ b/TOGSim/include/SimulationConfig.h @@ -1,13 +1,11 @@ #pragma once -#include #include - -using json = nlohmann::json; +#include enum class CoreType { WS_MESH, STONNE }; -enum class DramType { SIMPLE, RAMULATOR1, RAMULATOR2 }; +enum class DramType { SIMPLE, RAMULATOR2 }; enum class IcntType { SIMPLE, BOOKSIM2 }; diff --git a/TOGSim/include/Simulator.h b/TOGSim/include/Simulator.h index 39fa310e..a0b8b9c5 100644 --- a/TOGSim/include/Simulator.h +++ b/TOGSim/include/Simulator.h @@ -24,8 +24,15 @@ namespace fs = std::filesystem; class Simulator { public: Simulator(SimulationConfig config); - void schedule_graph(int partion_id, std::unique_ptr tile_graph) { - _partition_scheduler.at(partion_id)->schedule_graph(std::move(tile_graph)); + void enqueue_graph(int partion_id, std::unique_ptr tile_graph) { + if (partion_id < 0 || static_cast(partion_id) >= _config.num_partition) { + spdlog::error("[Enqueue_graph] Invalid partition_id: {} (valid range: 0 to {}). " + "Total partitions: {}", partion_id, _config.num_partition - 1, _config.num_partition); + throw std::runtime_error( + fmt::format("[Enqueue_graph] Invalid partition_id: {} (valid range: 0 to {}). " + "Total partitions: {}", partion_id, _config.num_partition - 1, _config.num_partition)); + } + _partition_scheduler.at(partion_id)->enqueue_graph(std::move(tile_graph)); } void run_simulator(); cycle_type get_core_cycle() { return _core_cycles; } diff --git a/TOGSim/include/SparseCore.h b/TOGSim/include/SparseCore.h index 9188b21d..a91004ed 100644 --- a/TOGSim/include/SparseCore.h +++ b/TOGSim/include/SparseCore.h @@ -1,5 +1,6 @@ #include #include +#include #include "Core.h" #include "sstStonne.h" #include "SimpleMem.h" @@ -58,7 +59,8 @@ class SparseCore : public Core { void print_stats() override; void print_current_stats() override; std::shared_ptr pop_finished_tile() override; - void finish_instruction(std::shared_ptr& inst) override; + void finish_instruction(std::shared_ptr& inst, + InstFinishTraceTag tag = InstFinishTraceTag::Fnshed) override; void dumpTrace(int stonne_core_id, const std::string& path); bool isTraceMode(int stonne_core_id) { return traceMode.at(stonne_core_id); } void setTraceMode(int stonne_core_id, bool mode) { traceMode.at(stonne_core_id) = mode; } diff --git a/TOGSim/include/TileGraph.h b/TOGSim/include/TileGraph.h index 990c107d..869bbb11 100644 --- a/TOGSim/include/TileGraph.h +++ b/TOGSim/include/TileGraph.h @@ -4,21 +4,33 @@ #include #include #include +#include #include "Tile.h" +#include "Common.h" #include "IntervalTree.h" +class TileGraph; + class TileSubGraph { public: TileSubGraph(); void add_tile(std::shared_ptr tile); void finish_tile(std::shared_ptr tile); - bool is_finished() { return _ready_tile_queue.empty() && _tile_set.empty(); } + /** True when no tile is executing on a core and no work remains in this subgraph. */ + bool is_finished() { + return _in_flight_tiles == 0 && _ready_tile_queue.empty() && _tile_set.empty(); + } + void on_tile_issued(); + void add_parallel_buffer(void* ptr); + void release_parallel_buffers(); const std::shared_ptr peek_tile(); std::shared_ptr get_tile(); int get_id() { return _id; } void set_core_id(int core_id) { _core_id = core_id; } int get_core_id() { return _core_id; } void init_cache_plan(std::shared_ptr> plan) { _cache_plan = plan; } + void set_owner_tile_graph(std::shared_ptr g); + std::shared_ptr lock_owner_tile_graph() const; bool is_cacheable(unsigned long long start, unsigned long long end) { return _cache_plan->findOverlapping(start, end).size() != 0; } struct CompareReadyTile { bool operator()(const std::shared_ptr& a, const std::shared_ptr& b) const { @@ -29,16 +41,21 @@ class TileSubGraph { protected: std::priority_queue, std::vector>, CompareReadyTile> _ready_tile_queue; std::set> _tile_set; + int _in_flight_tiles = 0; + std::vector _parallel_buffers; int _id; int _core_id = -1; static int _next_id; std::shared_ptr> _cache_plan; + std::weak_ptr _owner_tile_graph; }; -class TileGraph { +class TileGraph : public std::enable_shared_from_this { public: TileGraph(std::string path, std::string name) : _path(path), _name(name), _subgraph_vec(), _cpu_graph_map() {} void append_subgraph(std::shared_ptr subgraph); + /** Call once the TileGraph is owned by shared_ptr (e.g. from Scheduler::enqueue_graph). */ + void wire_subgraph_owner_links(); bool empty(int core_id) { if (_vec_index != _subgraph_vec.size()) { return false; @@ -56,6 +73,9 @@ class TileGraph { return true; } bool is_finished(); + /** Idempotent: logs kernel completion once at the simulated cycle when the graph is fully finished. */ + void try_emit_kernel_complete(cycle_type at_cycle, int scheduler_partition_id = -1); + bool kernel_complete_logged() const { return _kernel_complete_logged; } const std::shared_ptr peek_tile(int core_id, int slot_id); std::shared_ptr get_tile(int core_id, int slot_id); void allocate_subgraph(int core_id, int slot_id); @@ -67,6 +87,10 @@ class TileGraph { std::string get_name() { return _name; } void set_arrival_time(cycle_type arrival_time) { _arrival_time = arrival_time; } cycle_type get_arrival_time() { return _arrival_time; } + void set_kernel_id(unsigned int kernel_id) { _kernel_id = kernel_id; } + unsigned int get_kernel_id() { return _kernel_id; } + void set_start_time(cycle_type start_time) { _start_time = start_time; } + cycle_type get_start_time() { return _start_time; } void init_cache_plan(IntervalTree::interval_vector it) { _cache_plan = std::make_shared>(std::move(it)); } @@ -130,6 +154,7 @@ class TileGraph { int _vec_index=0; std::string _path; std::string _name = "?"; + unsigned int _kernel_id = 0; std::vector _loop_index_list; std::vector> _ranges; std::vector> _subgraph_vec; @@ -137,5 +162,7 @@ class TileGraph { std::map>> _cpu_graph_map; std::shared_ptr> _cache_plan; cycle_type _arrival_time; + cycle_type _start_time = 0; // First tile issue time, 0 means not started yet + bool _kernel_complete_logged = false; static std::shared_ptr null_tile; }; \ No newline at end of file diff --git a/TOGSim/include/TileGraphParser.h b/TOGSim/include/TileGraphParser.h index 9cc61d4a..d255a735 100644 --- a/TOGSim/include/TileGraphParser.h +++ b/TOGSim/include/TileGraphParser.h @@ -2,19 +2,18 @@ #include #include #include -#include +#include #include #include #include "TileGraph.h" #include "Instruction.h" #include "sstStonne.h" #include "IntervalTree.h" +#include "Common.h" #include "onnx/defs/schema.h" #include "onnx/onnx-operators_pb.h" #include "onnx/onnx_pb.h" -using json = nlohmann::json; - enum class TileType{ LOOP_INDEX_NODE, LOOP_END_NODE, @@ -35,8 +34,6 @@ enum class LoopType { INNER_LOOP }; -bool loadConfig(const std::string& config_path, json& config_json); - class TileNode { public: TileNode(onnx::NodeProto& node); @@ -68,7 +65,7 @@ class TileNode { class TileGraphParser { public: - TileGraphParser(std::string onnx_path, std::string attribute_path, std::string config_path); + TileGraphParser(std::string onnx_path, std::string attribute_path, const YAML::Node& config_yaml); std::shared_ptr get_top_loop(); std::unique_ptr& get_tile_graph() { return _tile_graph; } addr_type lookup(std::string key); @@ -80,12 +77,12 @@ class TileGraphParser { LoopType get_loop_type(std::string key) { return std::get<2>(_loop_size_map[key]); } const std::map> & get_loop_map() { return _loop_size_map; } const std::vector &lookupNumaInfo(std::string key); - int getCoreIdFromJson(const json& attribute_json, int subgraph_id); + int getCoreIdFromConfig(const YAML::Node& attribute_config, int subgraph_id); std::string getMetaByName(std::string key) { return _tog_meta[key]; } - const json& get_attribute_file() { return _attribute_json; } - std::vector calc_tag(std::vector& accum_tag, std::vector& tag_idx, std::vector& tag_stride); - void register_memory_tag(std::string name, std::vector& tag_key); - bool check_memory_tag(std::string name, std::vector& tag_key); + const YAML::Node& get_attribute_file() { return _attribute_config; } + std::vector calc_tag(std::vector& accum_tag, std::vector& tag_idx, std::vector& tag_stride); + void register_memory_tag(std::string name, std::vector& tag_key); + bool check_memory_tag(std::string name, std::vector& tag_key); void clear_tag_table() { _tag_table.clear(); } std::string get_indirect_path() { namespace fs = std::filesystem; @@ -121,12 +118,12 @@ class TileGraphParser { uint64_t get_dma_counter() { return dma_counter; } void inc_dma_counter() { dma_counter++; } bool is_sparse_tile(uint64_t idx) { return sparse_tile_set.find(idx) != sparse_tile_set.end(); } - int register_addr_name(const std::string& addr_name) { + int64_t register_addr_name(const std::string& addr_name) { if (_addr_name_map.find(addr_name) == _addr_name_map.end()) - _addr_name_map[addr_name] = _addr_name_map.size(); + _addr_name_map[addr_name] = static_cast(_addr_name_map.size()); return _addr_name_map[addr_name]; } - int get_addr_name_id(const std::string& addr_name) { return _addr_name_map[addr_name]; } + int64_t get_addr_name_id(const std::string& addr_name) { return _addr_name_map[addr_name]; } private: void register_tile(std::shared_ptr tile_node); @@ -135,8 +132,8 @@ class TileGraphParser { void _tile_index_generate() {} int _loop_stack_pointer = 0; - json _attribute_json; - json _config_json; + YAML::Node _attribute_config; + YAML::Node _config_yaml; std::string _tog_path; std::string _attribute_path; uint64_t indirect_counter = 0; @@ -151,8 +148,8 @@ class TileGraphParser { std::vector> _cache_plan; std::map> _loop_size_map; std::map _tog_meta; - std::map>, uint32_t> _tag_table; - std::unordered_map _addr_name_map; + std::map>, uint32_t> _tag_table; + std::unordered_map _addr_name_map; }; class TileComputeNode : public TileNode { @@ -174,11 +171,11 @@ class TileMemoryNode : public TileNode { public: TileMemoryNode(onnx::NodeProto& node); std::string get_base_addr_name() { return _base_addr_name; } - size_t get_precision() { return _element_size; } + size_t get_elem_bits() const { return _elem_bits; } std::vector get_tile_size() { return _tile_size; } std::vector& get_tile_stride() { return _tile_stride; } std::vector& get_tag_idx_list() { return _tag_idx_list; } - std::vector& get_tag_stride_list() { return _tag_stride_list; } + std::vector& get_tag_stride_list() { return _tag_stride_list; } std::vector& get_loop_idx_list() { return _loop_idx_list; } std::vector& get_loop_stride_list () { return _loop_stride_list; } bool is_async_node() { return _is_async; } @@ -188,12 +185,12 @@ class TileMemoryNode : public TileNode { private: std::vector _tile_size; std::vector _tile_stride; - size_t _element_size; + size_t _elem_bits = 0; bool _is_async; bool _is_indirect; std::string _base_addr_name; std::vector _tag_idx_list; - std::vector _tag_stride_list; + std::vector _tag_stride_list; std::vector _loop_idx_list; std::vector _loop_stride_list; }; @@ -203,14 +200,14 @@ class TileMemoryWaitNode : public TileNode { TileMemoryWaitNode(onnx::NodeProto& node); std::string get_base_addr_name() { return _base_addr_name; } std::vector& get_tag_idx_list() { return _tag_idx_list; } - std::vector& get_tag_stride_list() { return _tag_stride_list; } - std::vector& get_tag_divider_list() { return _tag_divider_list; } + std::vector& get_tag_stride_list() { return _tag_stride_list; } + std::vector& get_tag_divider_list() { return _tag_divider_list; } void print_node() override; private: std::vector _tag_idx_list; - std::vector _tag_stride_list; - std::vector _tag_divider_list; + std::vector _tag_stride_list; + std::vector _tag_divider_list; std::string _base_addr_name; }; diff --git a/TOGSim/include/TraceLogTags.h b/TOGSim/include/TraceLogTags.h new file mode 100644 index 00000000..6c158099 --- /dev/null +++ b/TOGSim/include/TraceLogTags.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +/** Trace bracket tags: max 15 characters; use pad15() so logs show a fixed 15-char field (space-padded). */ +namespace TraceLogTag { + +/** Right-pad (or truncate) to exactly 15 characters for aligned log columns. */ +inline std::string pad15(std::string_view sv) { + if (sv.size() > 15) { + sv = sv.substr(0, 15); + } + std::string out(sv); + out.resize(15, ' '); + return out; +} + +inline constexpr const char* kTileScheduled = "TILE_SCHEDULED"; + +inline constexpr const char* kInstructionIssued = "INST_ISSUED"; +inline constexpr const char* kInstructionFinished = "INST_FINISHED"; +/** Async MOVIN skipped: same tag still in flight. */ +inline constexpr const char* kInstructionSkipped = "INST_SKIP"; + +inline constexpr const char* kAsyncDmaAllRequestsIssued = "ASYNC_DMA_ISSUE"; +inline constexpr const char* kAllDramResponsesReceived = "DRAM_RESP_DONE"; + +inline constexpr const char* kL2CacheableStatusForAddress = "L2CACHE_STAT"; +inline constexpr const char* kDmaNumaPlacement = "DRAM_NUMA"; + +/** Field label for get_global_inst_id() in trace lines (≤15 chars). */ +inline constexpr const char* kGlobalInstIdKey = "INST_ID"; +} // namespace TraceLogTag diff --git a/TOGSim/include/scheduler/Scheduler.h b/TOGSim/include/scheduler/Scheduler.h index 39ab7576..3cdf5b2e 100644 --- a/TOGSim/include/scheduler/Scheduler.h +++ b/TOGSim/include/scheduler/Scheduler.h @@ -1,5 +1,8 @@ #pragma once #include +#include +#include +#include #include "Tile.h" #include "Common.h" #include "TileGraph.h" @@ -8,8 +11,8 @@ class Scheduler { public: Scheduler(SimulationConfig config, const cycle_type* core_cycle, const uint64_t* core_time, int id); - void schedule_graph(std::unique_ptr tile_graph); - void finish_tile(std::shared_ptr tile) { tile->get_owner()->finish_tile(tile); } + void enqueue_graph(std::unique_ptr tile_graph); + void finish_tile(std::shared_ptr tile); /* For other schedulers */ virtual std::shared_ptr get_tile(int core_id=0, int slot_id=0); @@ -22,7 +25,10 @@ class Scheduler { int _id; const cycle_type* _core_cycle; const uint64_t* _core_time; - std::deque> _tile_graph; + /** Scheduling queue (front = current kernel for issue). */ + std::deque> _tile_graph; + /** Keeps TileGraph alive until kernel completion is logged (may extend past pop from _tile_graph). */ + std::unordered_set> _in_flight_graphs; SimulationConfig _config; struct CompareTile { diff --git a/TOGSim/src/Common.cc b/TOGSim/src/Common.cc index 9a6b7798..b15381a6 100644 --- a/TOGSim/src/Common.cc +++ b/TOGSim/src/Common.cc @@ -1,28 +1,41 @@ #include "Common.h" -uint32_t generate_id() { - static uint32_t id_counter{0}; - return id_counter++; +bool loadConfig(const std::string& config_path, YAML::Node& config_yaml) { + try { + config_yaml = YAML::LoadFile(config_path); + spdlog::info("[LoadConfig] Success to open \"{}\"", config_path); + return true; + } catch (const YAML::BadFile& e) { + spdlog::error("[LoadConfig] Failed to open \"{}\" (File not found or inaccessible)", config_path); + return false; + } catch (const YAML::ParserException& e) { + spdlog::error("[LoadConfig] Failed to parse YAML file \"{}\": {}", config_path, e.what()); + return false; + } catch (const std::exception& e) { + spdlog::error("[LoadConfig] Unknown error loading \"{}\": {}", config_path, e.what()); + return false; + } } template -T get_config_value(json config, std::string key) { - if (config.contains(key)) { - return config[key]; +T get_config_value(const YAML::Node& config, std::string key) { + if (config[key]) { + return config[key].as(); } else { throw std::runtime_error(fmt::format("Config key {} not found", key)); } } -SimulationConfig initialize_config(json config) { +SimulationConfig initialize_config(YAML::Node config) { SimulationConfig parsed_config; - // print json - spdlog::info("TOGSim Config: {}", config.dump(2)); + YAML::Emitter emitter; + emitter << config; + spdlog::info("PyTorchSim config:\n{}", emitter.c_str()); /* Core configs */ - parsed_config.num_cores = config["num_cores"]; - if (config.contains("core_type")) { - std::vector core_types = config["core_type"].get>(); + parsed_config.num_cores = get_config_value(config, "num_cores"); + if (config["core_type"]) { + std::vector core_types = config["core_type"].as>(); if (core_types.size() != parsed_config.num_cores) throw std::runtime_error("Mismatch between num_cores and core_type list size"); @@ -41,100 +54,105 @@ SimulationConfig initialize_config(json config) { for (int i=0; i(config, "core_freq_mhz"); + if (config["num_systolic_array_per_core"]) + parsed_config.num_systolic_array_per_core = config["num_systolic_array_per_core"].as(); + if (config["num_stonne_per_core"]) + parsed_config.num_stonne_per_core = config["num_stonne_per_core"].as(); + if (config["num_stonne_port"]) + parsed_config.num_stonne_port = config["num_stonne_port"].as(); parsed_config.core_print_interval = get_config_value(config, "core_stats_print_period_cycles"); - /* Stonne config */ - if (config.contains("stonne_config_path")) - parsed_config.stonne_config_path = config["stonne_config_path"]; + /* Stonne config */ + if (config["stonne_config_path"]) + parsed_config.stonne_config_path = config["stonne_config_path"].as(); /* DRAM config */ - if ((std::string)config["dram_type"] == "simple") + std::string dram_type_str = get_config_value(config, "dram_type"); + + if (dram_type_str == "simple") { parsed_config.dram_type = DramType::SIMPLE; - else if ((std::string)config["dram_type"] == "ramulator") - parsed_config.dram_type = DramType::RAMULATOR1; - else if ((std::string)config["dram_type"] == "ramulator2") + parsed_config.dram_latency = get_config_value(config, "dram_latency"); + } else if (dram_type_str == "ramulator2") { parsed_config.dram_type = DramType::RAMULATOR2; - else - throw std::runtime_error(fmt::format("Not implemented dram type {} ", - (std::string)config["dram_type"])); - parsed_config.dram_freq_mhz = config["dram_freq_mhz"]; - if (config.contains("dram_latency")) - parsed_config.dram_latency = config["dram_latency"]; - if (config.contains("ramulator_config_path")) - parsed_config.dram_config_path = config["ramulator_config_path"]; - parsed_config.dram_channels = config["dram_channels"]; - if (config.contains("dram_req_size_byte")) - parsed_config.dram_req_size = config["dram_req_size_byte"]; - if (config.contains("dram_stats_print_period_cycles")) - parsed_config.dram_print_interval = config["dram_stats_print_period_cycles"]; - if(config.contains("dram_num_burst_length")) - parsed_config.dram_nbl = config["dram_num_burst_length"]; - if (config.contains("dram_num_partitions")) { - parsed_config.dram_num_partitions = config["dram_num_partitions"]; + parsed_config.dram_config_path = get_config_value(config, "ramulator_config_path"); + } else { + throw std::runtime_error(fmt::format("Not implemented dram type {} ", dram_type_str)); + } + + parsed_config.dram_freq_mhz = get_config_value(config, "dram_freq_mhz"); + parsed_config.dram_channels = get_config_value(config, "dram_channels"); + parsed_config.dram_req_size = get_config_value(config, "dram_req_size_byte"); + parsed_config.dram_nbl = get_config_value(config, "dram_num_burst_length"); + + if (config["dram_stats_print_period_cycles"]) + parsed_config.dram_print_interval = config["dram_stats_print_period_cycles"].as(); + if (config["dram_num_partitions"]) { + parsed_config.dram_num_partitions = config["dram_num_partitions"].as(); if (parsed_config.dram_channels % parsed_config.dram_num_partitions != 0) { throw std::runtime_error("[Config] DRAM channels must be divisible by dram_num_partitions"); } } - parsed_config.dram_channels_per_partitions = - parsed_config.dram_channels / parsed_config.dram_num_partitions; + if (parsed_config.dram_num_partitions != 0) { + parsed_config.dram_channels_per_partitions = + parsed_config.dram_channels / parsed_config.dram_num_partitions; + } else { + parsed_config.dram_channels_per_partitions = parsed_config.dram_channels; + } /* L2D config */ - if (config.contains("l2d_type")) { - if ((std::string)config["l2d_type"] == "nocache") + if (config["l2d_type"]) { + std::string l2d_type_str = config["l2d_type"].as(); + if (l2d_type_str == "nocache") parsed_config.l2d_type = L2CacheType::NOCACHE; - else if ((std::string)config["l2d_type"] == "datacache") + else if (l2d_type_str == "datacache") { parsed_config.l2d_type = L2CacheType::DATACACHE; - else - throw std::runtime_error(fmt::format("Not implemented l2 cache type {} ", - (std::string)config["l2d_type"])); + parsed_config.l2d_config_str = get_config_value(config, "l2d_config"); + if (config["l2d_hit_latency"]) + parsed_config.l2d_hit_latency = config["l2d_hit_latency"].as(); + } else + throw std::runtime_error(fmt::format("Not implemented l2 cache type {} ", l2d_type_str)); } else { parsed_config.l2d_type = L2CacheType::NOCACHE; } - if (config.contains("l2d_config")) - parsed_config.l2d_config_str = config["l2d_config"]; - if (config.contains("l2d_hit_latency")) - parsed_config.l2d_config_str = config["l2d_hit_latency"]; - /* Icnt config */ - if ((std::string)config["icnt_type"] == "simple") + std::string icnt_type_str = config["icnt_type"].as(); + if (icnt_type_str == "simple") { parsed_config.icnt_type = IcntType::SIMPLE; - else if ((std::string)config["icnt_type"] == "booksim2") + if (config["icnt_latency_cycles"]) + parsed_config.icnt_latency = config["icnt_latency_cycles"].as(); + } else if (icnt_type_str == "booksim2") { parsed_config.icnt_type = IcntType::BOOKSIM2; - else - throw std::runtime_error(fmt::format("Not implemented icnt type {} ", - (std::string)config["icnt_type"])); - parsed_config.icnt_freq_mhz = config["icnt_freq_mhz"]; - if (config.contains("icnt_latency_cycles")) - parsed_config.icnt_latency = config["icnt_latency_cycles"]; - if (config.contains("booksim_config_path")) - parsed_config.icnt_config_path = config["booksim_config_path"]; - if (config.contains("icnt_stats_print_period_cycles")) - parsed_config.icnt_stats_print_period_cycles = config["icnt_stats_print_period_cycles"]; - if (config.contains("icnt_injection_ports_per_core")) - parsed_config.icnt_injection_ports_per_core = config["icnt_injection_ports_per_core"]; - - if (config.contains("scheduler")) - parsed_config.scheduler_type = config["scheduler"]; - if (config.contains("num_partition")) - parsed_config.num_partition = config["num_partition"]; - if (config.contains("partition")) { + parsed_config.icnt_config_path = get_config_value(config, "booksim_config_path"); + } else + throw std::runtime_error(fmt::format("Not implemented icnt type {} ", icnt_type_str)); + + parsed_config.icnt_freq_mhz = config["icnt_freq_mhz"].as(); + if (config["icnt_stats_print_period_cycles"]) + parsed_config.icnt_stats_print_period_cycles = config["icnt_stats_print_period_cycles"].as(); + if (config["icnt_injection_ports_per_core"]) + parsed_config.icnt_injection_ports_per_core = config["icnt_injection_ports_per_core"].as(); + + if (config["scheduler"]) + parsed_config.scheduler_type = config["scheduler"].as(); + if (config["num_partition"]) + parsed_config.num_partition = config["num_partition"].as(); + if (config["partition"]) { for (int i=0; i(); + parsed_config.partiton_map[i] = partition_id; + spdlog::info("[Config/Core] CPU {}: Partition {}", i, partition_id); + } else { + spdlog::warn("[Config/Core] CPU {}: Partition key not found, defaulting to 0", i); + parsed_config.partiton_map[i] = 0; + } } } else { - /* Default: all partition 0 */ for (int i=0; i +#include Core::Core(uint32_t id, SimulationConfig config) : _id(id), @@ -6,7 +9,7 @@ Core::Core(uint32_t id, SimulationConfig config) _core_cycle(0), _stat_dma_cycle(0), _num_systolic_array_per_core(config.num_systolic_array_per_core), - _dma(id, config.dram_req_size) { + _dma(id, config.dram_req_size, config.l2d_type != L2CacheType::NOCACHE) { _sa_compute_pipeline.resize(_num_systolic_array_per_core); _stat_tot_sa_compute_cycle.resize(_num_systolic_array_per_core); _stat_sa_compute_cycle.resize(_num_systolic_array_per_core); @@ -22,9 +25,9 @@ bool Core::can_issue(const std::shared_ptr& op) { } void Core::issue(std::shared_ptr op) { - if (op->get_instructions().size()){ - spdlog::trace("[{}][Core {}][TILE_SCHEDULED]", - _core_cycle, _id); + if (op->get_instructions().size()) { + core_trace_log::trace_tile_scheduled(_core_cycle, _id, + TraceLogTag::pad15(TraceLogTag::kTileScheduled)); } for (const auto& inst : op->get_instructions()) { if (inst->is_ready()) @@ -62,9 +65,9 @@ void Core::vu_cycle() { if (!_vu_compute_pipeline.empty()) { _stat_vu_compute_cycle++; if(_vu_compute_pipeline.front()->finish_cycle <= _core_cycle) { - int bubble = _vu_compute_pipeline.front()->bubble_cycle; + cycle_type bubble = _vu_compute_pipeline.front()->bubble_cycle; _stat_vu_compute_idle_cycle += bubble; - _stat_vu_compute_cycle -= bubble; + _stat_vu_compute_cycle = (bubble < _stat_vu_compute_cycle) ? (_stat_vu_compute_cycle - bubble) : 0; finish_instruction(_vu_compute_pipeline.front()); _vu_compute_pipeline.pop(); } else { @@ -83,9 +86,10 @@ void Core::sa_cycle() { while (retry) { if (!_sa_compute_pipeline.at(i).empty()) { if(_sa_compute_pipeline.at(i).front()->finish_cycle <= _core_cycle) { - int bubble = _sa_compute_pipeline.at(i).front()->bubble_cycle; + cycle_type bubble = _sa_compute_pipeline.at(i).front()->bubble_cycle; _stat_sa_compute_idle_cycle.at(i) += bubble; - _stat_sa_compute_cycle.at(i) -= bubble; + cycle_type& stat = _stat_sa_compute_cycle.at(i); + stat = (bubble < stat) ? (stat - bubble) : 0; finish_instruction(_sa_compute_pipeline.at(i).front()); _sa_compute_pipeline.at(i).pop(); } else { @@ -119,13 +123,16 @@ void Core::dma_cycle() { if (instruction->is_dma_read() && instruction->is_async_dma()) { auto& key = instruction->get_tag_id(); assert(!_dma.get_tag_finish(instruction->subgraph_id, key)); + spdlog::trace( + "[{}][Core {}] TOG async DMA response (table notify): tag_addr=0x{:016x} global_inst_id={} " + "subgraph_id={}", + _core_cycle, + _id, + static_cast(static_cast(instruction->get_addr_id())), + instruction->get_global_inst_id(), + instruction->subgraph_id); _dma.set_tag_finish(instruction->subgraph_id, key); - spdlog::trace("[{}][Core {}] {} ASYNC FINISHED, subgraph_id: {} addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", - _core_cycle, _id, opcode_to_string(instruction->get_opcode()), - instruction->subgraph_id, instruction->get_addr_name(), - fmt::format("[{}]", fmt::join(instruction->get_tag_id(), ", ")), - fmt::format("[{}]", fmt::join(instruction->get_tag_idx_list(), ", ")), - fmt::format("[{}]", fmt::join(instruction->get_tag_stride_list(), ", "))); + finish_instruction(instruction, InstFinishTraceTag::DmaRespComplete); for (auto & wait_inst : _dma.get_tag_waiter(instruction->subgraph_id, key)) { _dma.mark_tag_used(instruction->subgraph_id, key); finish_instruction(wait_inst); @@ -142,18 +149,18 @@ void Core::dma_cycle() { /* Only DMA write operation is finished! */ finish_instruction(finished_inst); } else if (finished_inst->is_dma_read() && finished_inst->is_async_dma()) { - /* Register tag table for async dma load */ - _dma.register_tag(finished_inst->subgraph_id, finished_inst->get_tag_id()); - finish_instruction(finished_inst); + /* Register tag table for async dma load; see TraceLogTag::kAsyncDmaAllRequestsIssued */ + finish_instruction(finished_inst, InstFinishTraceTag::DmaIssueComplete); } else if(!finished_inst->is_dma_read()) { - spdlog::error("[{}][Core {}] DMA instruction in not valid", _core_cycle, _id); + core_trace_log::log_error_dma_instruction_invalid(_core_cycle, _id); exit(EXIT_FAILURE); } else if (finished_inst->get_opcode() == Opcode::BAR) { - spdlog::trace("[{}][Core {}] {} FINISHED, addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _core_cycle, _id, - opcode_to_string(finished_inst->get_opcode()), finished_inst->get_addr_name(), - fmt::format("[{}]", fmt::join(finished_inst->get_tag_id(), ", ")), - fmt::format("[{}]", fmt::join(finished_inst->get_tag_idx_list(), ", ")), - fmt::format("[{}]", fmt::join(finished_inst->get_tag_stride_list(), ", "))); + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15(TraceLogTag::kInstructionFinished), + finished_inst->get_global_inst_id(), + core_trace_log::format_instruction_detail_line( + *finished_inst)); } /*Pass to waiting queue */ _dma_waiting_queue[finished_inst.get()] = std::move(finished_inst); @@ -222,34 +229,37 @@ void Core::cycle() { finish_instruction(inst); else _dma.register_tag_waiter(inst->subgraph_id, key, inst); - spdlog::trace("[{}][Core {}][SIKIPPED] {}, addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _core_cycle, _id, - opcode_to_string(inst->get_opcode()), - inst->get_addr_name(), - fmt::format("[{}]", fmt::join(inst->get_tag_id(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15( + TraceLogTag::kInstructionSkipped), + inst->get_global_inst_id(), + core_trace_log::format_dma_inst_issued_trace_line( + *inst)); issued = true; _stat_tot_skipped_inst.at(static_cast(inst->get_opcode()))++; break; } else { - spdlog::trace("[{}][Core {}][INST_ISSUED] {}, addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _core_cycle, _id, - opcode_to_string(inst->get_opcode()), - inst->get_addr_name(), - fmt::format("[{}]", fmt::join(inst->get_tag_id(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15( + TraceLogTag::kInstructionIssued), + inst->get_global_inst_id(), + core_trace_log::format_dma_inst_issued_trace_line( + *inst)); + _dma.register_tag(inst->subgraph_id, inst->get_tag_id()); _ld_inst_queue.push(inst); issued = true; break; } } case Opcode::MOVOUT: - spdlog::trace("[{}][Core {}][INST_ISSUED] {}, addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _core_cycle, _id, - opcode_to_string(inst->get_opcode()), - inst->get_addr_name(), - fmt::format("[{}]", fmt::join(inst->get_tag_id(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15(TraceLogTag::kInstructionIssued), + inst->get_global_inst_id(), + core_trace_log::format_dma_inst_issued_trace_line( + *inst)); _st_inst_queue.push(inst); issued = true; break; @@ -272,8 +282,13 @@ void Core::cycle() { _stat_tot_skipped_inst.at(static_cast(inst->get_opcode()))++; instructions.erase(it); } else { - spdlog::trace("[{}][Core {}][INST_ISSUED][SA {}] {}-{}, finsh at {}", _core_cycle, _id, _systolic_array_rr, - opcode_to_string(inst->get_opcode()), inst->get_compute_type(), inst->finish_cycle); + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15( + TraceLogTag::kInstructionIssued), + inst->get_global_inst_id(), + core_trace_log::format_instruction_detail_line( + *inst)); target_pipeline.push(inst); issued = true; if (inst->get_compute_type()) { @@ -299,16 +314,18 @@ void Core::cycle() { } else { _dma.register_tag_waiter(inst->subgraph_id, key, inst); } - spdlog::trace("[{}][Core {}][INST_ISSUED] {}, addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _core_cycle, _id, - opcode_to_string(inst->get_opcode()), inst->get_addr_name(), - fmt::format("[{}]", fmt::join(inst->get_tag_id(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15( + TraceLogTag::kInstructionIssued), + inst->get_global_inst_id(), + core_trace_log::format_instruction_detail_line( + *inst)); issued = true; } break; default: - spdlog::error("Undefined instruction opcode type"); + core_trace_log::log_error_undefined_opcode(); exit(EXIT_FAILURE); } @@ -340,27 +357,34 @@ void Core::cycle() { } } -void Core::finish_instruction(std::shared_ptr& inst) { +void Core::finish_instruction(std::shared_ptr& inst, InstFinishTraceTag tag) { + if (tag == InstFinishTraceTag::DmaRespComplete) { + if (!inst->finished) { + core_trace_log::log_error_dram_responses_trace_not_finished(_core_cycle, _id); + exit(EXIT_FAILURE); + } + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15(TraceLogTag::kAllDramResponsesReceived), + inst->get_global_inst_id(), + core_trace_log::format_instruction_detail_line(*inst)); + return; + } if (inst->finished) { - spdlog::error("[{}][Core {}][ERROR] {} inst already finished!!", _core_cycle, _id, - opcode_to_string(inst->get_opcode())); + core_trace_log::log_error_instruction_already_finished(_core_cycle, _id, + opcode_to_string(inst->get_opcode())); exit(EXIT_FAILURE); } inst->finish_instruction(); static_cast(inst->get_owner())->inc_finished_inst(); - if (inst->get_opcode() == Opcode::COMP) { - spdlog::trace("[{}][Core {}][INST_FINISHED] {}-{}", - _core_cycle, _id, opcode_to_string(inst->get_opcode()), inst->get_compute_type()); - } else if (inst->get_opcode() != Opcode::BAR && inst->is_async_dma()){ - spdlog::trace("[{}][Core {}][ASYNC] {} subgraph_id: {} addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", - _core_cycle, _id, opcode_to_string(inst->get_opcode()), inst->subgraph_id, inst->get_addr_name(), - inst->get_tag_id(), - fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), - fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); - } else if ((inst->get_opcode() == Opcode::MOVIN || inst->get_opcode() == Opcode::MOVOUT) && !inst->is_async_dma()) { - spdlog::trace("[{}][Core {}][INST_FINISHED] {} addr_name: {}", _core_cycle, _id, - opcode_to_string(inst->get_opcode()), inst->get_addr_name()); - } + const char* trace_tag = (tag == InstFinishTraceTag::DmaIssueComplete) + ? TraceLogTag::kAsyncDmaAllRequestsIssued + : TraceLogTag::kInstructionFinished; + core_trace_log::trace_instruction_line(_core_cycle, + _id, + TraceLogTag::pad15(trace_tag), + inst->get_global_inst_id(), + core_trace_log::format_instruction_detail_line(*inst)); } bool Core::running() { diff --git a/TOGSim/src/CoreTraceLog.cc b/TOGSim/src/CoreTraceLog.cc new file mode 100644 index 00000000..ebc31de0 --- /dev/null +++ b/TOGSim/src/CoreTraceLog.cc @@ -0,0 +1,122 @@ +#include "CoreTraceLog.h" + +#include + +#include +#include +#include + +namespace core_trace_log { + +std::string format_dma_inst_issued_detail(Instruction& inst) { + const auto& ts = inst.get_tile_size(); + const int rank = static_cast(std::max(1, ts.size())); + if (inst.get_opcode() == Opcode::MOVIN) { + return fmt::format( + "addr_name={} dram=0x{:016x} rank={} size=[{}] stride=[{}] elem_bits={} async={} indirect={} tag_id=[{}]", + inst.get_addr_name(), + static_cast(inst.get_base_dram_address()), + rank, + fmt::join(ts, ","), + fmt::join(inst.get_tile_stride(), ","), + inst.get_elem_bits(), + inst.is_async_dma(), + inst.is_indirect_mode(), + format_tag_key_list_hex(inst.get_tag_id())); + } + uint64_t tag_hex = 0; + const auto& tidx = inst.get_tag_idx_list(); + if (!tidx.empty()) { + tag_hex = static_cast(tidx[0]); + } + return fmt::format( + "addr_name={} dram=0x{:016x} rank={} elem_bits={} async={} indirect={} tag=0x{:016x} stride=[{}] size=[{}] " + "tag_idx=[{}]", + inst.get_addr_name(), + static_cast(inst.get_base_dram_address()), + rank, + inst.get_elem_bits(), + inst.is_async_dma(), + inst.is_indirect_mode(), + tag_hex, + fmt::join(inst.get_tile_stride(), ","), + fmt::join(ts, ","), + fmt::join(tidx, ",")); +} + +std::string format_dma_inst_issued_trace_line(Instruction& inst) { + return fmt::format("{} ({})", opcode_to_string(inst.get_opcode()), format_dma_inst_issued_detail(inst)); +} + +std::string format_instruction_detail_line(Instruction& inst) { + const Opcode op = inst.get_opcode(); + const std::string opname = opcode_to_string(op); + if (op == Opcode::COMP) { + return fmt::format("{} (compute_type={} compute_cycle={} overlapping_cycle={})", + opname, + inst.get_compute_type(), + inst.get_compute_cycle(), + inst.get_overlapping_cycle()); + } + if ((op == Opcode::MOVIN || op == Opcode::MOVOUT) && inst.is_async_dma()) { + return fmt::format("{} (ASYNC subgraph_id={} addr_name={} tag_id=[{}] tag_idx=[{}] tag_stride=[{}])", + opname, + inst.subgraph_id, + inst.get_addr_name(), + format_tag_key_list_hex(inst.get_tag_id()), + fmt::join(inst.get_tag_idx_list(), ","), + fmt::join(inst.get_tag_stride_list(), ",")); + } + if (op == Opcode::MOVIN || op == Opcode::MOVOUT) { + return fmt::format("{} (addr_name={})", opname, inst.get_addr_name()); + } + if (op == Opcode::BAR) { + return fmt::format("{} (addr_name={} tag_id=[{}] tag_idx=[{}] tag_stride=[{}])", + opname, + inst.get_addr_name(), + format_tag_key_list_hex(inst.get_tag_id()), + fmt::join(inst.get_tag_idx_list(), ","), + fmt::join(inst.get_tag_stride_list(), ",")); + } + return opname; +} + +void trace_tile_scheduled(cycle_type core_cycle, uint32_t core_id, const std::string& tag15) { + spdlog::trace("[{}][Core {}][{}]", core_cycle, core_id, tag15); +} + +void trace_instruction_line(cycle_type core_cycle, + uint32_t core_id, + const std::string& tag15, + uint64_t global_inst_id, + const std::string& message) { + spdlog::trace("[{}][Core {}][{}][{}={}] {}", + core_cycle, + core_id, + tag15, + TraceLogTag::kGlobalInstIdKey, + global_inst_id, + message); +} + +void log_error_dma_instruction_invalid(cycle_type core_cycle, uint32_t core_id) { + spdlog::error("[{}][Core {}] DMA instruction in not valid", core_cycle, core_id); +} + +void log_error_dram_responses_trace_not_finished(cycle_type core_cycle, uint32_t core_id) { + spdlog::error("[{}][Core {}][ERROR] ALL_DRAM_RESPONSES_RECEIVED trace but inst not finished yet", + core_cycle, + core_id); +} + +void log_error_instruction_already_finished(cycle_type core_cycle, + uint32_t core_id, + const std::string& opcode_name) { + spdlog::error("[{}][Core {}][ERROR] {} inst already finished!!", core_cycle, core_id, opcode_name); +} + +void log_error_undefined_opcode() { + spdlog::error("Undefined instruction opcode type"); +} + +} // namespace core_trace_log diff --git a/TOGSim/src/DMA.cc b/TOGSim/src/DMA.cc index f8f21025..5d509953 100644 --- a/TOGSim/src/DMA.cc +++ b/TOGSim/src/DMA.cc @@ -1,9 +1,11 @@ #include "DMA.h" #include "TileGraph.h" +#include "TraceLogTags.h" -DMA::DMA(uint32_t id, uint32_t dram_req_size) { +DMA::DMA(uint32_t id, uint32_t dram_req_size, bool l2_datacache_enabled) { _id = id; _dram_req_size = dram_req_size; + _l2_datacache_enabled = l2_datacache_enabled; _current_inst = nullptr; _finished = true; } @@ -12,7 +14,7 @@ void DMA::issue_tile(std::shared_ptr inst) { _current_inst = std::move(inst); std::vector& tile_size = _current_inst->get_tile_size(); if (tile_size.size() <= 0 || tile_size.size() > get_max_dim()) { - spdlog::error("[DMA {}] issued tile is not supported format..", _id); + spdlog::error("[DMA {}] issued tile is not supported format.. tile.size: {}, tile_size: [{}]", _id, tile_size.size(), fmt::join(tile_size, ", ")); exit(EXIT_FAILURE); } _finished = false; @@ -31,12 +33,27 @@ std::shared_ptr> DMA::get_memory_access(cycle_type core_ bool is_cacheable = owner_subgraph->is_cacheable(base_daddr, base_daddr + _dram_req_size); - spdlog::trace("[{}][Core {}][SRAM] Address: 0x{:016x}, Is_cacheable: {}", - core_cycle, _id, base_daddr, is_cacheable); - spdlog::trace("[{}][Core {}][NUMA] Subgraph id: {} , Numa id: {}, Arg: {} is_write: {}", - core_cycle, _id, owner_subgraph->get_core_id(), - _current_inst->get_numa_id(), _current_inst->get_addr_name(), - _current_inst->is_dma_write()); + if (_l2_datacache_enabled) { + spdlog::trace( + "[{}][Core {}][{}][INST_ID={}] dram=0x{:016x} cacheable={}", + core_cycle, + _id, + TraceLogTag::pad15(TraceLogTag::kL2CacheableStatusForAddress), + _current_inst->get_global_inst_id(), + base_daddr, + is_cacheable); + } + spdlog::trace( + "[{}][Core {}][{}][INST_ID={}] core_id={} subgraph_id={} numa_id={} addr_name={} is_write={}", + core_cycle, + _id, + TraceLogTag::pad15(TraceLogTag::kDmaNumaPlacement), + _current_inst->get_global_inst_id(), + owner_subgraph->get_core_id(), + _current_inst->subgraph_id, + _current_inst->get_numa_id(), + _current_inst->get_addr_name(), + _current_inst->is_dma_write()); for (const auto& addr : *addr_set) { mem_access_type acc_type = _current_inst->is_dma_write() ? mem_access_type::GLOBAL_ACC_W diff --git a/TOGSim/src/Dram.cc b/TOGSim/src/Dram.cc index 089c582e..581406f4 100644 --- a/TOGSim/src/Dram.cc +++ b/TOGSim/src/Dram.cc @@ -1,13 +1,86 @@ #include "Dram.h" +#include + +namespace { + +static bool is_power_of_2_u32(uint32_t n) { return n != 0 && (n & (n - 1)) == 0; } + +static uint32_t floor_log2_u32(uint32_t n) { + uint32_t r = 0; + while (n >>= 1) + ++r; + return r; +} + +/** Smallest power of two >= n (n >= 1). */ +static uint32_t next_power_of_2_u32(uint32_t n) { + if (n <= 1) + return 1; + --n; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n + 1; +} + +/** Bytes/s effective GB/s and avg-per-channel utilization % for a window of `window_cycles` DRAM ticks. */ +struct DramBwSnapshot { + double bandwidth_gbs = 0; + double util_avg_ch_pct = 0; +}; + +DramBwSnapshot make_dram_bw_snapshot(long long total_rw_transactions, uint64_t window_cycles, + uint32_t n_ch, uint32_t req_size, uint32_t n_bl, + double dram_freq_mhz) { + DramBwSnapshot out; + if (window_cycles == 0 || n_ch == 0) + return out; + const double tx = static_cast(total_rw_transactions); + const double w = static_cast(window_cycles); + const double bytes_per_cycle = tx * static_cast(req_size) / w; + out.bandwidth_gbs = bytes_per_cycle * dram_freq_mhz / 1000.0; + const double avg_per_ch = tx / static_cast(n_ch); + out.util_avg_ch_pct = avg_per_ch * 100.0 * static_cast(n_bl) / (2.0 * w); + return out; +} + +} // namespace + +new_addr_type Dram::partition_dram_address(new_addr_type raw_addr) const { + if (_req_size == 0 || _n_ch_per_partition == 0) + return raw_addr; + const new_addr_type tx = raw_addr >> _tx_log2; + const new_addr_type q = tx / _n_ch_per_partition; + return static_cast(q << _tx_log2); +} + uint32_t Dram::get_channel_id(mem_fetch* access) { - uint32_t channel_id; - if (_n_ch_per_partition >= 16) - channel_id = ipoly_hash_function((new_addr_type)access->get_addr()/_req_size, 0, _n_ch_per_partition); - else - channel_id = ipoly_hash_function((new_addr_type)access->get_addr()/_req_size, 0, 16) % _n_ch_per_partition; + uint32_t channel_in_partition = 0; + if (_n_ch_per_partition > 1) { + const new_addr_type tx = static_cast(access->get_addr() >> _tx_log2); + new_addr_type rest_high; + unsigned init_index = 0; + if (is_power_of_2_u32(_n_ch_per_partition)) { + const unsigned lb = floor_log2_u32(_n_ch_per_partition); + rest_high = tx >> lb; + init_index = static_cast(tx & (_n_ch_per_partition - 1u)); + } else { + /* gpgpu-sim "gap" channels: quotient / remainder split at txn granularity. */ + rest_high = tx / _n_ch_per_partition; + init_index = static_cast(tx % _n_ch_per_partition); + } + /* ipoly_hash_function only implements 16/32/64 (see Hashing.cc); fold like addrdec IPOLY + mod when needed. */ + const uint32_t poly_n = next_power_of_2_u32(std::max(16u, _n_ch_per_partition)); + const uint32_t poly_use = std::min(poly_n, 64u); + channel_in_partition = + static_cast(ipoly_hash_function(rest_high, init_index, poly_use)) % _n_ch_per_partition; + } - channel_id += ((access->get_numa_id() % _n_partitions)* _n_ch_per_partition); + const uint32_t channel_id = + channel_in_partition + static_cast(access->get_numa_id() % _n_partitions) * _n_ch_per_partition; return channel_id; } @@ -19,6 +92,7 @@ Dram::Dram(SimulationConfig config, cycle_type* core_cycle) { _n_partitions = config.dram_num_partitions; _n_ch_per_partition = config.dram_channels_per_partitions; _config = config; + _tx_log2 = static_cast(std::log2(_req_size)); spdlog::info("[Config/DRAM] DRAM Bandwidth {} GB/s, Freq: {} MHz, Channels: {}, Request_size: {}B", config.max_dram_bandwidth(), config.dram_freq_mhz, _n_ch, _req_size); /* Initialize DRAM Channels */ @@ -54,7 +128,8 @@ DramRamulator2::DramRamulator2(SimulationConfig config, cycle_type* core_cycle) _mem.resize(_n_ch); for (int ch = 0; ch < _n_ch; ch++) { _mem[ch] = std::make_unique( - ch, _n_ch, config.dram_config_path, "Ramulator2", _config.dram_print_interval, _n_bl); + ch, _n_ch, config.dram_config_path, "Ramulator2", _config.dram_print_interval, _n_bl, + _req_size, config.dram_freq_mhz); } _tx_log2 = log2(_req_size); _tx_ch_log2 = log2(_n_ch_per_partition) + _tx_log2; @@ -86,6 +161,30 @@ void DramRamulator2::cycle() { _mem[ch]->return_queue_pop(); } } + + if (_n_ch == 0) + return; + const int iv = _config.dram_print_interval; + if (iv <= 0) + return; + const uint64_t cc = *_core_cycles; + if (cc % static_cast(iv) != 0 || cc == 0) + return; + + const double f_mhz = static_cast(_config.dram_freq_mhz); + const uint64_t w = static_cast(iv); + for (int ch = 0; ch < _n_ch; ch++) { + const long long r = _mem[ch]->interval_reads(); + const long long wtxn = _mem[ch]->interval_writes(); + const DramBwSnapshot bw = + make_dram_bw_snapshot(r + wtxn, w, 1u, _req_size, _n_bl, f_mhz); + spdlog::info( + "[DRAM] ch {} | BW {:.2f} GB/s, {:.2f}% util | {} reads, {} writes", + ch, bw.bandwidth_gbs, bw.util_avg_ch_pct, r, wtxn); + } + for (int ch = 0; ch < _n_ch; ch++) { + _mem[ch]->reset_interval_bw_counters(); + } } void DramRamulator2::cache_cycle() { @@ -99,7 +198,8 @@ bool DramRamulator2::is_full(uint32_t cid, mem_fetch* request) { } void DramRamulator2::push(uint32_t cid, mem_fetch* request) { - addr_type target_addr = (request->get_addr() >> _tx_ch_log2) << _tx_log2; + const addr_type raw_addr = request->get_addr(); + const addr_type target_addr = partition_dram_address(raw_addr); request->set_addr(target_addr); m_from_crossbar_queue[cid].push(request); } @@ -119,9 +219,44 @@ void DramRamulator2::pop(uint32_t cid) { } void DramRamulator2::print_stat() { + spdlog::info("========= DRAM stat ========="); + if (_n_ch == 0) + return; + for (int ch = 0; ch < _n_ch; ch++) { - _mem[ch]->print(stdout); + _mem[ch]->finalize_once(); } + + spdlog::info("=== Ramulator2 stats (channels 0.. {}) ===", _n_ch - 1); + for (int ch = 0; ch < _n_ch; ch++) { + std::cout << "--- channel " << ch << " ---\n"; + _mem[ch]->print_stats_yaml(std::cout); + } + std::cout.flush(); + + const uint64_t cycles = *_core_cycles; + if (cycles == 0) + return; + const double f_mhz = static_cast(_config.dram_freq_mhz); + spdlog::info("[DRAM] per-channel avg BW ({} sim cycles):", cycles); + long long tr_all = 0; + long long tw_all = 0; + for (int ch = 0; ch < _n_ch; ch++) { + const long long tr = _mem[ch]->total_reads(); + const long long tw = _mem[ch]->total_writes(); + tr_all += tr; + tw_all += tw; + const DramBwSnapshot bw = + make_dram_bw_snapshot(tr + tw, cycles, 1u, _req_size, _n_bl, f_mhz); + spdlog::info( + "[DRAM] ch {} | avg BW {:.2f} GB/s, {:.2f}% util | {} reads, {} writes", + ch, bw.bandwidth_gbs, bw.util_avg_ch_pct, tr, tw); + } + const DramBwSnapshot bw_all = make_dram_bw_snapshot( + tr_all + tw_all, cycles, _n_ch, _req_size, _n_bl, f_mhz); + spdlog::info( + "[DRAM] all ch 0..{} | avg BW {:.2f} GB/s, {:.2f}% util (avg/ch) | {} reads, {} writes", + _n_ch - 1, bw_all.bandwidth_gbs, bw_all.util_avg_ch_pct, tr_all, tw_all); } void DramRamulator2::print_cache_stats() { @@ -137,8 +272,6 @@ SimpleDRAM::SimpleDRAM(SimulationConfig config, cycle_type* core_cycle) : Dram(c _mem.push_back(std::make_unique>("SimpleDRAM", true, -1)); } _latency = config.dram_latency; - _tx_log2 = log2(_req_size); - _tx_ch_log2 = log2(_n_ch_per_partition) + _tx_log2; } bool SimpleDRAM::running() { diff --git a/TOGSim/src/Instruction.cc b/TOGSim/src/Instruction.cc index aef9079c..f236d160 100644 --- a/TOGSim/src/Instruction.cc +++ b/TOGSim/src/Instruction.cc @@ -1,5 +1,23 @@ #include "Instruction.h" +#include + +uint64_t Instruction::_next_global_inst_id = 0; + +std::string format_tag_key_list_hex(const std::vector& tag_keys) { + if (tag_keys.empty()) { + return {}; + } + std::string out; + for (size_t i = 0; i < tag_keys.size(); ++i) { + if (i > 0) { + out.push_back(','); + } + out += fmt::format("0x{:016x}", static_cast(tag_keys[i])); + } + return out; +} + std::string opcode_to_string(Opcode opcode) { switch (opcode) { case Opcode::MOVIN: return "MOVIN"; @@ -11,13 +29,14 @@ std::string opcode_to_string(Opcode opcode) { } Instruction::Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, - addr_type dram_addr, std::vector tile_size, std::vector tile_stride, size_t precision, - std::vector tag_idx_list, std::vector tag_stride_list, - std::vector accum_tag_idx_list) + addr_type dram_addr, std::vector tile_size, std::vector tile_stride, size_t elem_bits, + std::vector tag_idx_list, std::vector tag_stride_list, + std::vector accum_tag_idx_list) : opcode(opcode), compute_cycle(compute_cycle), ready_counter(num_parents), dram_addr(dram_addr), - tile_size(tile_size), tile_stride(tile_stride), _precision(precision), + tile_size(tile_size), tile_stride(tile_stride), _elem_bits(elem_bits), _tag_idx_list(tag_idx_list), _tag_stride_list(tag_stride_list), _accum_tag_idx_list(accum_tag_idx_list) { + _global_inst_id = _next_global_inst_id++; assert(_tag_idx_list.size()==_tag_stride_list.size()); _tile_numel = 1; for (auto dim : tile_size) @@ -26,6 +45,7 @@ Instruction::Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_par Instruction::Instruction(Opcode opcode) : opcode(opcode) { + _global_inst_id = _next_global_inst_id++; _tile_numel = 1; } @@ -51,9 +71,9 @@ void Instruction::dec_waiting_request() { void Instruction::prepare_tag_key() { /* Calculate tag key */ - int key_offset = 0; + int64_t key_offset = 0; _tag_key.push_back(_addr_id); - for (int i=0; i<_tag_idx_list.size(); i++) + for (size_t i = 0; i < _tag_idx_list.size(); i++) key_offset += _tag_idx_list.at(i) * _tag_stride_list.at(i); for (auto accum_dim : _accum_tag_idx_list) _tag_key.push_back(accum_dim); @@ -88,10 +108,10 @@ std::shared_ptr> Instruction::get_dram_address(addr_type dra dim1*tile_stride.at(tile_stride.size() - 3) + \ dim2*tile_stride.at(tile_stride.size() - 2) + \ dim3*tile_stride.at(tile_stride.size() - 1); - address = dram_addr + address * _precision; + address = dram_addr + (address * _elem_bits + 7) >> 3; if (indirect_index != NULL) { uint64_t index_val = indirect_index[index_count++]; - address += index_val * _precision; + address += (index_val * _elem_bits + 7) >> 3; } address_set->insert(address - (address & dram_req_size-1)); } diff --git a/TOGSim/src/Simulator.cc b/TOGSim/src/Simulator.cc index 857923c5..d7fe9f1b 100644 --- a/TOGSim/src/Simulator.cc +++ b/TOGSim/src/Simulator.cc @@ -121,7 +121,7 @@ void Simulator::icnt_cycle() { front->set_core_id(core_id); if (!_icnt->is_full(port_id, front)) { int node_id = _dram->get_channel_id(front) / _config.dram_channels_per_partitions; - if (core_id == node_id) + if (get_partition_id(core_id) == node_id) _cores[core_id]->inc_numa_local_access(); else _cores[core_id]->inc_numa_remote_access(); @@ -170,55 +170,8 @@ void Simulator::icnt_cycle() { _icnt->cycle(); } -int Simulator::until(cycle_type until_cycle) { - std::vector partition_scheudler_status; - for (auto &scheduler : _partition_scheduler) - partition_scheudler_status.push_back(scheduler->empty()); - - while (until_cycle == -1 || _core_cycles < until_cycle) { - set_cycle_mask(); - // Core Cycle - if (IS_CORE_CYCLE(_cycle_mask)) - core_cycle(); - - // DRAM cycle - if (IS_DRAM_CYCLE(_cycle_mask)) - dram_cycle(); - - // Interconnect cycle - if (IS_ICNT_CYCLE(_cycle_mask)) - icnt_cycle(); - - // Check if core status has changed - if (_core_cycles % 10 == 0) { - int bitmap = 0; - for (int i=0; i<_partition_scheduler.size(); i++) { - /* Skip this */ - if (partition_scheudler_status.at(i)) - continue; - - if (_partition_scheduler.at(i)->empty()) { - bitmap |= (1 << i); - } - } - if (bitmap) - return bitmap; - } - } - int bitmap = 0; - for (int i=0; i<_partition_scheduler.size(); i++) { - /* Skip this */ - if (partition_scheudler_status.at(i)) - continue; - - if (_partition_scheduler.at(i)->empty()) - bitmap |= (1ULL << i); - } - return bitmap; -} - void Simulator::cycle() { - while (running()) { + while (running() || _core_cycles < 1) { set_cycle_mask(); // Core Cycle if (IS_CORE_CYCLE(_cycle_mask)) @@ -232,7 +185,6 @@ void Simulator::cycle() { if (IS_ICNT_CYCLE(_cycle_mask)) icnt_cycle(); } - spdlog::info("Simulation finished"); for (auto &core: _cores) { core->check_tag(); } diff --git a/TOGSim/src/SparseCore.cc b/TOGSim/src/SparseCore.cc index d5629b9c..1bf1163a 100644 --- a/TOGSim/src/SparseCore.cc +++ b/TOGSim/src/SparseCore.cc @@ -1,4 +1,5 @@ #include "SparseCore.h" +#include "TraceLogTags.h" SparseCore::SparseCore(uint32_t id, SimulationConfig config) : Core(id, config) { /* Init stonne cores*/ @@ -239,7 +240,11 @@ void SparseCore::subCoreCycle(uint32_t subcore_id) { { auto acc_type = mem_access_type::GLOBAL_ACC_R; auto type = mf_type::READ_REQUEST; - spdlog::trace("[{}][StonneCore {}/{}][INST_ISSUED] {}", _core_cycle, _id, subcore_id, + spdlog::trace("[{}][StonneCore {}/{}][{}] {}", + _core_cycle, + _id, + subcore_id, + TraceLogTag::pad15(TraceLogTag::kInstructionIssued), opcode_to_string(inst->get_opcode())); for (auto addr : inst->get_trace_address()) { addr = addr - (addr & _config.dram_req_size-1); @@ -260,7 +265,11 @@ void SparseCore::subCoreCycle(uint32_t subcore_id) { { auto acc_type = mem_access_type::GLOBAL_ACC_W; auto type = mf_type::WRITE_REQUEST; - spdlog::trace("[{}][StonneCore {}/{}][INST_ISSUED] {}", _core_cycle, _id, subcore_id, + spdlog::trace("[{}][StonneCore {}/{}][{}] {}", + _core_cycle, + _id, + subcore_id, + TraceLogTag::pad15(TraceLogTag::kInstructionIssued), opcode_to_string(inst->get_opcode())); for (auto addr : inst->get_trace_address()) { addr = addr - (addr & _config.dram_req_size-1); @@ -285,8 +294,13 @@ void SparseCore::subCoreCycle(uint32_t subcore_id) { inst->finish_cycle = _core_cycle + inst->get_compute_cycle(); else inst->finish_cycle = target_pipeline.back()->finish_cycle + inst->get_compute_cycle(); - spdlog::trace("[{}][StonneCore {}/{}][INST_ISSUED] {}, finsh at {}", _core_cycle, _id, subcore_id, - opcode_to_string(inst->get_opcode()), inst->finish_cycle); + spdlog::trace("[{}][StonneCore {}/{}][{}] {}, finish_at={}", + _core_cycle, + _id, + subcore_id, + TraceLogTag::pad15(TraceLogTag::kInstructionIssued), + opcode_to_string(inst->get_opcode()), + inst->finish_cycle); target_pipeline.push(inst); issued = true; } @@ -397,7 +411,22 @@ std::shared_ptr SparseCore::pop_finished_tile() { return result; } -void SparseCore::finish_instruction(std::shared_ptr& inst) { +void SparseCore::finish_instruction(std::shared_ptr& inst, InstFinishTraceTag tag) { + if (tag == InstFinishTraceTag::DmaRespComplete) { + if (!inst->finished) { + spdlog::error("[{}][StonneCore {}][Error] ALL_DRAM_RESPONSES_RECEIVED trace but inst not finished", + _core_cycle, + _id); + exit(EXIT_FAILURE); + } + spdlog::trace("[{}][StonneCore {}][{}][INST_ID={}] {}", + _core_cycle, + _id, + TraceLogTag::pad15(TraceLogTag::kAllDramResponsesReceived), + inst->get_global_inst_id(), + opcode_to_string(inst->get_opcode())); + return; + } if (inst->finished) { spdlog::error("[{}][StonneCore {}][Error] {} inst already finished!!", _core_cycle, _id, opcode_to_string(inst->get_opcode())); @@ -405,12 +434,16 @@ void SparseCore::finish_instruction(std::shared_ptr& inst) { } inst->finish_instruction(); static_cast(inst->get_owner())->inc_finished_inst(); + const char* trace_tag = (tag == InstFinishTraceTag::DmaIssueComplete) + ? TraceLogTag::kAsyncDmaAllRequestsIssued + : TraceLogTag::kInstructionFinished; + const std::string tag15 = TraceLogTag::pad15(trace_tag); if (inst->get_opcode() == Opcode::COMP) { - spdlog::info("[{}][StonneCore {}][INST_FINISHED] {}", - _core_cycle, _id, opcode_to_string(inst->get_opcode())); + spdlog::info("[{}][StonneCore {}][{}] {}", _core_cycle, _id, tag15, + opcode_to_string(inst->get_opcode())); } else if (inst->get_opcode() == Opcode::MOVIN || inst->get_opcode() == Opcode::MOVOUT) { - spdlog::info("[{}][StonneCore {}][INST_FINISHED] {}", _core_cycle, _id, - opcode_to_string(inst->get_opcode())); + spdlog::info("[{}][StonneCore {}][{}] {}", _core_cycle, _id, tag15, + opcode_to_string(inst->get_opcode())); } } diff --git a/TOGSim/src/Tile.cc b/TOGSim/src/Tile.cc index 2e05cb08..12e4373b 100644 --- a/TOGSim/src/Tile.cc +++ b/TOGSim/src/Tile.cc @@ -26,6 +26,14 @@ void Tile::append_instuction(std::shared_ptr& inst) { } void Tile::append_child(std::shared_ptr child) { + if (!child) { + return; + } + for (const auto& existing : _child_tiles) { + if (existing == child) { + return; + } + } child->inc_ready_counter(); _child_tiles.push_back(std::move(child)); } diff --git a/TOGSim/src/TileGraph.cc b/TOGSim/src/TileGraph.cc index 120d49e2..b18e16b2 100644 --- a/TOGSim/src/TileGraph.cc +++ b/TOGSim/src/TileGraph.cc @@ -1,9 +1,32 @@ #include "TileGraph.h" +#include + +#include + int TileSubGraph::_next_id = 0; TileSubGraph::TileSubGraph() : _ready_tile_queue(), _tile_set(), _id(_next_id++) { } +void TileSubGraph::set_owner_tile_graph(std::shared_ptr g) { _owner_tile_graph = std::move(g); } + +std::shared_ptr TileSubGraph::lock_owner_tile_graph() const { return _owner_tile_graph.lock(); } + +void TileSubGraph::on_tile_issued() { _in_flight_tiles++; } + +void TileSubGraph::add_parallel_buffer(void* ptr) { + if (ptr) { + _parallel_buffers.push_back(ptr); + } +} + +void TileSubGraph::release_parallel_buffers() { + for (void* p : _parallel_buffers) { + std::free(p); + } + _parallel_buffers.clear(); +} + void TileSubGraph::add_tile(std::shared_ptr tile) { for (auto& inst : tile->get_instructions()) inst->subgraph_id = _id; @@ -15,7 +38,9 @@ void TileSubGraph::add_tile(std::shared_ptr tile) { } void TileSubGraph::finish_tile(std::shared_ptr tile) { - /* TODO. */ + if (_in_flight_tiles > 0) { + _in_flight_tiles--; + } tile->finish_tile(); for (auto child_tile_ptr: tile->get_child_tile()) { if (child_tile_ptr->get_ready_counter()) @@ -24,7 +49,9 @@ void TileSubGraph::finish_tile(std::shared_ptr tile) { _ready_tile_queue.push(child_tile_ptr); _tile_set.erase(child_tile_ptr); } - return; + if (is_finished()) { + release_parallel_buffers(); + } } const std::shared_ptr TileSubGraph::peek_tile() { @@ -51,6 +78,15 @@ void TileGraph::append_subgraph(std::shared_ptr subgraph) { _subgraph_vec.push_back(std::move(subgraph)); } +void TileGraph::wire_subgraph_owner_links() { + std::shared_ptr self = shared_from_this(); + for (const auto& sg : _subgraph_vec) { + if (sg) { + sg->set_owner_tile_graph(self); + } + } +} + bool TileGraph::is_finished() { bool finished = _subgraph_vec.empty(); /* Check all outer loop is allocated */ @@ -66,6 +102,33 @@ bool TileGraph::is_finished() { return finished; } +void TileGraph::try_emit_kernel_complete(cycle_type at_cycle, int scheduler_partition_id) { + if (_kernel_complete_logged || !is_finished()) { + return; + } + _kernel_complete_logged = true; + const unsigned int kernel_id = _kernel_id; + cycle_type start_time = _start_time; + cycle_type compute_time = 0; + if (start_time > 0) { + compute_time = at_cycle - start_time; + } else { + start_time = _arrival_time; + compute_time = at_cycle - start_time; + } + if (scheduler_partition_id >= 0) { + spdlog::info("[Scheduler {}] Kernel {} has completed (simulated) - TOG path: {} operation: {} finished at cycle {}", + scheduler_partition_id, kernel_id, _path, _name, at_cycle); + spdlog::info("[Scheduler {}] Kernel {} execution summary - Started at: {} cycles, Total compute time: {} cycles", + scheduler_partition_id, kernel_id, start_time, compute_time); + } else { + spdlog::info("Kernel {} has completed (simulated) - TOG path: {} operation: {} finished at cycle {}", + kernel_id, _path, _name, at_cycle); + spdlog::info("Kernel {} execution summary - Started at: {} cycles, Total compute time: {} cycles", + kernel_id, start_time, compute_time); + } +} + const std::shared_ptr TileGraph::peek_tile(int core_id, int slot_id) { std::shared_ptr ret = std::make_unique(Tile(Tile::Status::EMPTY)); if (_cpu_graph_map.find(core_id) == _cpu_graph_map.end()) { @@ -100,7 +163,12 @@ std::shared_ptr TileGraph::get_tile(int core_id, int slot_id) { allocate_subgraph(core_id, slot_id); return ret; } - return _cpu_graph_map[core_id][slot_id]->get_tile(); + auto& sg = _cpu_graph_map[core_id][slot_id]; + auto t = sg->get_tile(); + if (t->get_status() != Tile::Status::EMPTY) { + sg->on_tile_issued(); + } + return t; } void TileGraph::allocate_subgraph(int core_id, int slot_id) { diff --git a/TOGSim/src/TileGraphParser.cc b/TOGSim/src/TileGraphParser.cc index 42776a51..5060d336 100644 --- a/TOGSim/src/TileGraphParser.cc +++ b/TOGSim/src/TileGraphParser.cc @@ -1,18 +1,5 @@ #include "TileGraphParser.h" -bool loadConfig(const std::string& config_path, json& config_json) { - std::ifstream config_file(config_path); - if (config_file.is_open()) { - config_file >> config_json; - config_file.close(); - spdlog::info("[LoadConfig] Success to open \"{}\"", config_path); - return true; - } else { - spdlog::error("[LoadConfig] Failed to open \"{}\"", config_path); - return false; - } -} - void printIndexMap(std::string prefix, const std::map& indexMap) { std::ostringstream oss; for (const auto& [key, value] : indexMap) { @@ -87,26 +74,33 @@ bool find_output_idx(TileGraphParser* tog_parser, std::vector& output_ m = output_idx.at(0); n = output_idx.at(1); k = output_idx.at(2); + auto attr_file = tog_parser->get_attribute_file(); - auto attr_json = tog_parser->get_attribute_file(); + if (!attr_file["zero_skip"]) { + return false; + } - // Check arg0: m -> k + YAML::Node zero_skip = attr_file["zero_skip"]; bool found_arg0 = false; - if (attr_json["zero_skip"].contains("arg0")) { - auto& arg0 = attr_json["zero_skip"]["arg0"]; - if (arg0.contains(std::to_string(m)) && arg0[std::to_string(m)].contains(std::to_string(k))) { + if (zero_skip["arg0"]) { + YAML::Node arg0 = zero_skip["arg0"]; + std::string m_str = std::to_string(m); + std::string k_str = std::to_string(k); + if (arg0[m_str] && arg0[m_str][k_str]) { found_arg0 = true; } } - // Check arg1: n -> k bool found_arg1 = false; - if (attr_json["zero_skip"].contains("arg1")) { - auto& arg1 = attr_json["zero_skip"]["arg1"]; - if (arg1.contains(std::to_string(k)) && arg1[std::to_string(k)].contains(std::to_string(n))) { + if (zero_skip["arg1"]) { + YAML::Node arg1 = zero_skip["arg1"]; + std::string k_str = std::to_string(k); + std::string n_str = std::to_string(n); + if (arg1[k_str] && arg1[k_str][n_str]) { found_arg1 = true; } } + return found_arg0 || found_arg1; } @@ -198,7 +192,7 @@ TileMemoryNode::TileMemoryNode(onnx::NodeProto& node) : TileNode(node) { if (attribute.name() == "torchsim_base_addr") { _base_addr_name = attribute.s(); } else if (attribute.name() == "torchsim_element_size") { - _element_size = attribute.i(); + _elem_bits = static_cast(attribute.i()); } else if (attribute.name() == "torchsim_tile_size") { for (int i = 0; i < attribute.ints_size(); i++) _tile_size.push_back(attribute.ints(i)); @@ -210,7 +204,7 @@ TileMemoryNode::TileMemoryNode(onnx::NodeProto& node) : TileNode(node) { _tag_idx_list.push_back(attribute.strings(i)); } else if (attribute.name() == "torchsim_tag_stride_list") { for (int i = 0; i < attribute.ints_size(); i++) - _tag_stride_list.push_back(attribute.ints(i)); + _tag_stride_list.push_back(static_cast(attribute.ints(i))); } else if (attribute.name() == "torchsim_loop_idx_list") { for (int i = 0; i < attribute.strings_size(); i++) _loop_idx_list.push_back(attribute.strings(i)); @@ -232,7 +226,7 @@ void TileMemoryNode::print_node() { TileNode::print_node(); std::string spaces(get_depth(), '\t'); spdlog::debug("{} base_addr_name: {}", spaces, _base_addr_name); - spdlog::debug("{} element_size: {}", spaces, _element_size); + spdlog::debug("{} elem_bits: {}", spaces, _elem_bits); spdlog::debug("{} loop_stride_list: {} ", spaces, _loop_stride_list); spdlog::debug("{} tile_size: {} ", spaces, _tile_size); spdlog::debug("{} tile_stride: {} ", spaces, _tile_stride); @@ -249,10 +243,10 @@ TileMemoryWaitNode::TileMemoryWaitNode(onnx::NodeProto& node) : TileNode(node) { _tag_idx_list.push_back(attribute.strings(i)); } else if (attribute.name() == "torchsim_tag_stride_list") { for (int i = 0; i < attribute.ints_size(); i++) - _tag_stride_list.push_back(attribute.ints(i)); + _tag_stride_list.push_back(static_cast(attribute.ints(i))); } else if (attribute.name() == "torchsim_tag_divider_list") { for (int i = 0; i < attribute.ints_size(); i++) - _tag_divider_list.push_back(attribute.ints(i)); + _tag_divider_list.push_back(static_cast(attribute.ints(i))); } else if (attribute.name() == "torchsim_base_addr") { _base_addr_name = attribute.s(); } @@ -358,12 +352,12 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa /* Base address setting */ std::string base_addr_name = mem_node->get_base_addr_name(); - int base_addr_id = tog_parser->register_addr_name(base_addr_name); + int64_t base_addr_id = tog_parser->register_addr_name(base_addr_name); addr_type base_addr = tog_parser->lookup(base_addr_name); addr_type offset = std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); - std::vector tag_list; - std::vector accum_tag_list; + std::vector tag_list; + std::vector accum_tag_list; std::vector outer_loop_idx; std::vector outer_loop_size; /* Add accumulation loop info to accum_tag list */ @@ -412,8 +406,8 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } /* Check need to make this memory node */ - std::vector& tag_stride_list = mem_node->get_tag_stride_list(); - std::vector key = tog_parser->calc_tag(accum_tag_list, tag_list, tag_stride_list); + std::vector& tag_stride_list = mem_node->get_tag_stride_list(); + std::vector key = tog_parser->calc_tag(accum_tag_list, tag_list, tag_stride_list); if (tog_parser->check_memory_tag(base_addr_name, key)) continue; tog_parser->register_memory_tag(base_addr_name, key); @@ -428,7 +422,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa std::shared_ptr inst = std::make_shared( Opcode::MOVIN, 0, 0, base_addr+offset, - mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_elem_bits(), tag_list, tag_stride_list, accum_tag_list ); inst->set_addr_name(base_addr_name, base_addr_id); @@ -471,7 +465,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa /* Lookup given name's address */ std::string base_addr_name = mem_node->get_base_addr_name(); - int base_addr_id = tog_parser->register_addr_name(base_addr_name); + int64_t base_addr_id = tog_parser->register_addr_name(base_addr_name); addr_type base_addr = tog_parser->lookup(base_addr_name); addr_type offset = std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); @@ -488,8 +482,8 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa std::shared_ptr inst = std::make_shared( Opcode::MOVOUT, 0, 0, base_addr+offset, - mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), - std::vector(1), mem_node->get_tag_stride_list(), std::vector() + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_elem_bits(), + std::vector(1, 0), mem_node->get_tag_stride_list(), std::vector() ); inst->set_addr_name(base_addr_name, base_addr_id); inst->prepare_tag_key(); @@ -506,15 +500,15 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa printIndexMap("[TOGParser] DMA Wait Node ", iter); std::shared_ptr wait_node = std::static_pointer_cast(tile_node); auto base_addr_name = wait_node->get_base_addr_name(); - int base_addr_id = tog_parser->register_addr_name(base_addr_name); + int64_t base_addr_id = tog_parser->register_addr_name(base_addr_name); addr_type base_addr = tog_parser->lookup(base_addr_name); /* Lookup given name's address */ std::vector iter_list; - std::vector tag_list; - std::vector& tag_stride_list = wait_node->get_tag_stride_list(); - std::vector& tag_divider_list = wait_node->get_tag_divider_list(); - std::vector new_tag_stride_list; - std::vector accum_tag_list; + std::vector tag_list; + std::vector& tag_stride_list = wait_node->get_tag_stride_list(); + std::vector& tag_divider_list = wait_node->get_tag_divider_list(); + std::vector new_tag_stride_list; + std::vector accum_tag_list; auto& wait_tag_list = wait_node->get_tag_idx_list(); for (int i=0; i> TileLoopNode::get_tiles_from_iter(TileGraphPa } else if (tile_node->get_type() == TileType::COMPUTE_NODE) { printIndexMap("[TOGParser] Compute Node ", iter); std::shared_ptr compute_node = std::static_pointer_cast(tile_node); - std::vector tag_list = {0}; - std::vector tag_stride_list = {1}; - std::vector accum_tag_list; + std::vector tag_list = {0}; + std::vector tag_stride_list = {1}; + std::vector accum_tag_list; std::shared_ptr inst = std::make_shared( Opcode::COMP, compute_node->get_cycle(), 0, 0, @@ -593,9 +587,6 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa inst->add_child(child_inst); } } - /* Add instruction to tile */ - if (inst->get_opcode() == Opcode::MOVIN) - tile_vec.back()->inc_required_sram_size(inst->get_tile_numel() * inst->get_precision()); } link_map.clear(); /* iterate nested loop */ @@ -674,9 +665,6 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa inst->add_child(child_inst); } } - /* Add instruction to tile */ - if (inst->get_opcode() == Opcode::MOVIN) - tile_vec.back()->inc_required_sram_size(inst->get_tile_numel() * inst->get_precision()); } return tile_vec; @@ -691,50 +679,66 @@ void TileLoopNode::print_node() { spdlog::debug("{} stride: {} ", spaces, _stride); } -TileGraphParser::TileGraphParser(std::string onnx_path, std::string attribute_path, std::string config_path) { - loadConfig(attribute_path, _attribute_json); - loadConfig(config_path, _config_json); +TileGraphParser::TileGraphParser(std::string onnx_path, std::string attribute_path, const YAML::Node& config_yaml) { + loadConfig(attribute_path, _attribute_config); + _config_yaml = config_yaml; // Use the pre-loaded config _attribute_path = attribute_path; + if (!std::filesystem::exists(onnx_path)) { + throw std::runtime_error("Error: TOG graph path not found: " + onnx_path); + } /* Note: this parsing algorithm assume that all node are sorted in topological-order */ std::ifstream model_istream(onnx_path); google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); onnx::ModelProto model_proto; - + /* Attribute parsing */ - if (_attribute_json.contains("address_info")) { - auto address_info = _attribute_json["address_info"]; - for (auto it = address_info.begin(); it != address_info.end(); ++it) { - uint64_t value = it.value(); - _arg_to_address[it.key()] = value; - spdlog::info("[TOGParser/Attribute] Address Attribute key: {} address: 0x{:x}", it.key(), value); + if (_attribute_config["address_info"]) { + const auto& address_info = _attribute_config["address_info"]; + for (YAML::const_iterator it = address_info.begin(); it != address_info.end(); ++it) { + std::string key = it->first.as(); + uint64_t value = it->second.as(); + + _arg_to_address[key] = value; + spdlog::trace("[TOGParser/Attribute] Address Attribute key: {} address: 0x{:x}", key, value); } } - if (_attribute_json.contains("address_numa_stride")) { - auto address_numa_stride = _attribute_json["address_numa_stride"]; - for (auto it = address_numa_stride.begin(); it != address_numa_stride.end(); ++it) { - auto value_list = it.value(); - for (auto value : value_list) { - _arg_numa_stride[it.key()].push_back(value); + + if (_attribute_config["address_numa_stride"]) { + const auto& address_numa_stride = _attribute_config["address_numa_stride"]; + for (YAML::const_iterator it = address_numa_stride.begin(); it != address_numa_stride.end(); ++it) { + std::string key = it->first.as(); + const auto& value_list = it->second; // YAML Sequence Node + + for (const auto& val : value_list) { + _arg_numa_stride[key].push_back(val.as()); } - spdlog::info("[TOGParser/Attribute] Address numa info key: {} numa stride : {}", it.key(), fmt::join(_arg_numa_stride[it.key()], ", ")); + spdlog::trace("[TOGParser/Attribute] Address numa info key: {} numa stride : {}", key, fmt::join(_arg_numa_stride[key], ", ")); } } - if (_attribute_json.contains("sram_alloc") and _config_json.contains("l2d_type") and _config_json["l2d_type"] == "datacache") { - auto sram_alloc_list = _attribute_json["sram_alloc"]; + + if (_attribute_config["sram_alloc"] && + _config_yaml["l2d_type"] && + _config_yaml["l2d_type"].as() == "datacache") { + + auto sram_alloc_list = _attribute_config["sram_alloc"]; spdlog::info("[TOGParser/Attribute] ================= SRAM Alloc Plan ================"); - for (auto it = sram_alloc_list.begin(); it != sram_alloc_list.end(); ++it) { - auto value_list = it.value(); - unsigned long long start = value_list.at(0); - unsigned long long end = value_list.at(1); - spdlog::info("[TOGParser/Attribute] {:16s}: 0x{:016x} ~ 0x{:016x}", it.key(), start, end); + + for (YAML::const_iterator it = sram_alloc_list.begin(); it != sram_alloc_list.end(); ++it) { + std::string key = it->first.as(); + const auto& value_list = it->second; // List [start, end] + + unsigned long long start = value_list[0].as(); + unsigned long long end = value_list[1].as(); + + spdlog::info("[TOGParser/Attribute] {:16s}: 0x{:016x} ~ 0x{:016x}", key, start, end); Interval entry = {start, end, 0}; _cache_plan.push_back(entry); } } load_sparse_meta_data(); - /* ONNX file parsing */ + /* TOG file parsing */ _tog_path = onnx_path; model_proto.ParseFromZeroCopyStream(&zero_copy_input) && model_istream.eof(); @@ -744,7 +748,7 @@ TileGraphParser::TileGraphParser(std::string onnx_path, std::string attribute_pa /* Get meta data from graph */ for (const auto& meta : model_proto.metadata_props()) { - spdlog::info("[TOGParser] Register Metadata \"{}\": \"{}\"", meta.key(), meta.value()); + spdlog::trace("[TOGParser] Register Metadata \"{}\": \"{}\"", meta.key(), meta.value()); _tog_meta[meta.key()] = meta.value(); } @@ -835,7 +839,7 @@ TileGraphParser::TileGraphParser(std::string onnx_path, std::string attribute_pa /* Iterate outer loop and initialize inner loop */ for (auto iter=_tile_graph->begin(); iter!=_tile_graph->end(); ++iter) { std::shared_ptr subgraph = std::make_shared(); - subgraph->set_core_id(getCoreIdFromJson(_attribute_json, subgraph->get_id())); + subgraph->set_core_id(getCoreIdFromConfig(_attribute_config, subgraph->get_id())); auto indices = iter.get_indices(); for (auto loop : _loop_nodes.at(last_outer_idx)) { std::shared_ptr outer_loop = std::static_pointer_cast(loop); @@ -894,10 +898,10 @@ void TileGraphParser::register_tile(std::shared_ptr tile_node) { } } -std::vector TileGraphParser::calc_tag(std::vector& accum_tag, std::vector& tag_idx, std::vector& tag_stride) { - int key_offset = 0; - std::vector tag_key; - for (int i=0; i TileGraphParser::calc_tag(std::vector& accum_tag, std::vector& tag_idx, std::vector& tag_stride) { + int64_t key_offset = 0; + std::vector tag_key; + for (size_t i = 0; i < tag_idx.size(); i++) key_offset += tag_idx.at(i) * tag_stride.at(i); for (auto accum_dim : accum_tag) tag_key.push_back(accum_dim); @@ -905,12 +909,12 @@ std::vector TileGraphParser::calc_tag(std::vector& accum_tag, std::vec return tag_key; } -void TileGraphParser::register_memory_tag(std::string name, std::vector& tag_key) { +void TileGraphParser::register_memory_tag(std::string name, std::vector& tag_key) { assert(_tag_table.find(std::make_pair(name, tag_key))==_tag_table.end()); _tag_table[std::make_pair(name, tag_key)] = true; } -bool TileGraphParser::check_memory_tag(std::string name, std::vector& tag_key) { +bool TileGraphParser::check_memory_tag(std::string name, std::vector& tag_key) { return _tag_table.find(std::make_pair(name, tag_key))==_tag_table.end() ? false : true; } @@ -938,11 +942,12 @@ const std::vector& TileGraphParser::lookupNumaInfo(std::string key) { return _arg_numa_stride.at(key); } -int TileGraphParser::getCoreIdFromJson(const json& attribute_json, int subgraph_id) { - if (attribute_json.contains("subgraph_map")) { - const auto& subgraph_map = attribute_json["subgraph_map"]; - if (subgraph_map.contains(std::to_string(subgraph_id)) && subgraph_map[std::to_string(subgraph_id)].is_number_integer()) { - return subgraph_map[std::to_string(subgraph_id)]; +int TileGraphParser::getCoreIdFromConfig(const YAML::Node& attribute_config, int subgraph_id) { + std::string key = std::to_string(subgraph_id); + if (attribute_config["subgraph_map"]) { + const auto& subgraph_map = attribute_config["subgraph_map"]; + if (subgraph_map[key]) { + return subgraph_map[key].as(); } } return -1; diff --git a/TOGSim/src/helper/CommandLineParser.cc b/TOGSim/src/helper/CommandLineParser.cc index 66aebbe1..9cd177ac 100644 --- a/TOGSim/src/helper/CommandLineParser.cc +++ b/TOGSim/src/helper/CommandLineParser.cc @@ -12,9 +12,13 @@ void CommandLineParser::parse(int argc, char **argv) noexcept(false) { po::notify(variables_map); } +void CommandLineParser::print_help_message() const noexcept { + std::cout << options_description << std::endl; +} + void CommandLineParser::print_help_message_if_required() const noexcept { if (variables_map.count("help") > 0) { - std::cout << options_description << std::endl; + print_help_message(); exit(0); } } diff --git a/TOGSim/src/helper/CommandLineParser.h b/TOGSim/src/helper/CommandLineParser.h index 39174d5d..b41eabf3 100644 --- a/TOGSim/src/helper/CommandLineParser.h +++ b/TOGSim/src/helper/CommandLineParser.h @@ -19,7 +19,7 @@ class CommandLineParser { * Command Line Parser constructor */ CommandLineParser() noexcept { - options_description.add_options()("help", "Prints help message"); + options_description.add_options()("help,h", "Prints help message"); } /** @@ -38,6 +38,12 @@ class CommandLineParser { */ void print_help_message_if_required() const noexcept; + /** + * Prints the help message. + * (Can be called to show help for invalid options) + */ + void print_help_message() const noexcept; + /** * Add a new command line argument option. * (Should be called before `parse` method is called) diff --git a/TOGSim/src/main.cc b/TOGSim/src/main.cc index 77c1bae7..57e0e696 100644 --- a/TOGSim/src/main.cc +++ b/TOGSim/src/main.cc @@ -1,6 +1,9 @@ #include #include #include +#include +#include +#include #include "Simulator.h" #include "TileGraphParser.h" @@ -9,82 +12,83 @@ namespace fs = std::filesystem; namespace po = boost::program_options; -const char* env_value = std::getenv("TOGSIM_EAGER_MODE"); -bool isDryRun = (env_value != nullptr && std::string(env_value) == "1"); -void launchKernel(Simulator* simulator, std::string onnx_path, std::string attribute_path, std::string config_path, cycle_type request_time=0, int partiton_id=0) { - auto graph_praser = TileGraphParser(onnx_path, attribute_path, config_path); +void launchKernel(Simulator* simulator, unsigned int kernel_id, std::string onnx_path, std::string attribute_path, const YAML::Node& config_yaml, cycle_type request_time=0, int partiton_id=0, int device_id=0) { + auto graph_praser = TileGraphParser(onnx_path, attribute_path, config_yaml); std::unique_ptr& tile_graph = graph_praser.get_tile_graph(); tile_graph->set_arrival_time(request_time ? request_time : simulator->get_core_cycle()); - spdlog::info("[Scheduler {}] Register graph path: {} operation: {} at {}", partiton_id, onnx_path, tile_graph->get_name(), simulator->get_core_cycle()); - - simulator->schedule_graph(partiton_id, std::move(tile_graph)); + tile_graph->set_kernel_id(kernel_id); + spdlog::info("[Scheduler {}] Enqueued kernel id: {}, tog_path: {}, operation: {}, request_time: {}", partiton_id, kernel_id, onnx_path, tile_graph->get_name(), request_time); + simulator->enqueue_graph(partiton_id, std::move(tile_graph)); } -Simulator* create_simulator(std::string config_path) { - json config_json; - if(!loadConfig(config_path, config_json)) { - exit(1); +void process_trace_file(Simulator* simulator, std::string trace_file_path, const YAML::Node& config_yaml) { + // Open trace file (can be FIFO or regular file) + std::ifstream trace_file; + trace_file.open(trace_file_path); + if (!trace_file.is_open()) { + spdlog::error("[TOGSim] Failed to open trace file: {}", trace_file_path); + return; } - SimulationConfig config = initialize_config(config_json); - auto simulator = new Simulator(config); - return simulator; -} - -int until(Simulator *simulator, cycle_type until_cycle) { - return simulator->until(until_cycle); -} + spdlog::info("[TOGSim] Reading from trace file: {}", trace_file_path); -void interactive_mode(Simulator* simulator) { - std::string command; + // Read all available commands and process them + std::string line; + while (std::getline(trace_file, line)) { + if (line.empty()) { + continue; + } - std::cout << "[" << simulator->get_core_cycle() << "] TOGSim> "; - while (std::getline(std::cin, command)) { + // Skip comment lines starting with # + if (line[0] == '#') { + continue; + } - std::istringstream iss(command); + // Parse command: command_type,kernel_id,device_index,stream_index,tog_path,attribute_path,timestamp + std::istringstream iss(line); std::string token; - // Parse the first part of the command (e.g., "launch", "until", "quit") - iss >> token; - if (token == "launch") { - std::string onnx_path, attribute_path, config_path; - cycle_type request_time = 0; - int partition_id = 0; - iss >> config_path >> onnx_path >> attribute_path >> request_time >> partition_id; - - // Check if both paths were provided - if (onnx_path.empty() || attribute_path.empty()) { - spdlog::error("Error: Please provide both ONNX path and Attribute path in the format: launch onnx/path attribute/path"); - } else { - launchKernel(simulator, onnx_path, attribute_path, config_path, request_time, partition_id); - std::cerr << "launch done" << std::endl; - } - } else if (token == "until") { - cycle_type until_cycle; - iss >> until_cycle; - int reason; + std::vector tokens; + + while (std::getline(iss, token, ',')) { + tokens.push_back(token); + } + + if (tokens.size() != 7) { + spdlog::error("[TOGSim] Invalid command format. Expected: command_type,kernel_id,device_index,stream_index,tog_path,attribute_path,timestamp. Got: {} ({} tokens)", line, tokens.size()); + continue; + } - if (iss.fail()) { - spdlog::error("Error: Please provide a valid cycle number after 'until'"); + std::string command_type = tokens[0]; + unsigned int kernel_id = std::stoul(tokens[1]); + int device_index = std::stoi(tokens[2]); + int stream_index = std::stoi(tokens[3]); + std::string tog_path = tokens[4]; + std::string attribute_path = tokens[5]; + int timestamp = std::stoi(tokens[6]); + // timestamp (tokens[6]) is available but not used in current implementation + + try { + if (command_type == "LAUNCH_KERNEL") { + launchKernel(simulator, kernel_id, tog_path, attribute_path, config_yaml, timestamp, stream_index, device_index); + } else if (command_type == "DEVICE_SYNC") { + simulator->cycle(); + spdlog::info("[Device {}] Device synchronization completed", device_index); } else { - reason = simulator->until(until_cycle); - std::cerr << " Until finished: " << reason << std::endl; + spdlog::error("[TOGSim] Unknown command type: {}", command_type); } - } else if (token == "cycle") { - cycle_type current_cycle = simulator->get_core_cycle(); - std::cerr << "Current cycle: " << current_cycle << std::endl; - }else if (token == "quit") { - std::cerr << "Quit" << std::endl; - break; - } else { - spdlog::error("Error: unknown command {} Available commands are: launch, until, quit.", token); + } catch (const std::exception& e) { + spdlog::error("[TOGSim] Error processing command {} (type: {}): {}", kernel_id, command_type, e.what()); } - if (isDryRun) - std::cout << "[" << simulator->get_core_cycle() << "] TOGSim> "; } + trace_file.close(); simulator->cycle(); - if (simulator->get_core_cycle()==0) - simulator->until(0); - simulator->print_core_stat(); +} + +Simulator* create_simulator(const YAML::Node& config_yaml) { + SimulationConfig config = initialize_config(config_yaml); + + auto simulator = new Simulator(config); + return simulator; } int main(int argc, char** argv) { @@ -92,23 +96,32 @@ int main(int argc, char** argv) { // parse command line argumnet CommandLineParser cmd_parser = CommandLineParser(); cmd_parser.add_command_line_option( - "config", "Path for hardware configuration file"); + "config", "Path for hardware configuration file (.yml)"); cmd_parser.add_command_line_option( - "models_list", "Path for the models list file"); - cmd_parser.add_command_line_option( - "attributes_list", "Path for the models list file"); + "models_list", "Path for the trace file (.trace)"); cmd_parser.add_command_line_option( "log_level", "Set for log level [trace, debug, info], default = info"); - cmd_parser.add_command_line_option( - "mode", "choose \"trace\" moode and \"iteractive\" mode"); try { cmd_parser.parse(argc, argv); } catch (const CommandLineParser::ParsingError& e) { spdlog::error( - "Command line argument parrsing error captured. Error message: {}", + "Command line argument parsing error captured. Error message: {}", e.what()); - throw(e); + std::cerr << std::endl; + cmd_parser.print_help_message(); + exit(1); + } + + // Check if help was requested + cmd_parser.print_help_message_if_required(); + + // Dump full command for copy-paste re-run + std::ostringstream cmd_oss; + for (int i = 0; i < argc; ++i) { + if (i > 0) cmd_oss << " "; + cmd_oss << argv[i]; } + spdlog::info("[TOGSim] Run command: {}", cmd_oss.str()); std::string level = "info"; cmd_parser.set_if_defined("log_level", &level); @@ -120,29 +133,31 @@ int main(int argc, char** argv) { spdlog::set_level(spdlog::level::info); std::string config_path; - std::string onnx_path; - std::string attribute_path; - std::string execution_mode = "trace"; + std::string trace_file_path; /* Create simulator */ cmd_parser.set_if_defined("config", &config_path); - cmd_parser.set_if_defined("mode", &execution_mode); - auto simulator = create_simulator(config_path); - - if (execution_mode.compare("trace") == 0) { - /* Get needed info for launch kernel */ - cmd_parser.set_if_defined("models_list", &onnx_path); - cmd_parser.set_if_defined("attributes_list", &attribute_path); - - /* launch kernels */ - launchKernel(simulator, onnx_path, attribute_path, config_path); - simulator->run_simulator(); - if (simulator->get_core_cycle()==0) - simulator->until(1); + + // Load config once for reuse + YAML::Node config_yaml; + if (!loadConfig(config_path, config_yaml)) { + spdlog::error("[TOGSim] Failed to load config file: {}", config_path); + exit(1); + } + + auto simulator = create_simulator(config_yaml); + + // Get trace file path + cmd_parser.set_if_defined("models_list", &trace_file_path); + + if (!trace_file_path.empty()) { + // Process trace file (unified mode: supports both FIFO and regular file) + process_trace_file(simulator, trace_file_path, config_yaml); + spdlog::info("Simulation finished"); simulator->print_core_stat(); - } else if (execution_mode.compare("interactive") == 0) { - /* Get onnx_path, attribute from user input, request_time */ - interactive_mode(simulator); + } else { + spdlog::error("No trace file provided. Use --models_list to specify trace file path."); + exit(1); } delete simulator; diff --git a/TOGSim/src/scheduler/Scheduler.cc b/TOGSim/src/scheduler/Scheduler.cc index bb5d29cf..be361e7f 100644 --- a/TOGSim/src/scheduler/Scheduler.cc +++ b/TOGSim/src/scheduler/Scheduler.cc @@ -4,14 +4,32 @@ Scheduler::Scheduler(SimulationConfig config, const cycle_type* core_cycle, cons : _id(id), _config(config), _core_cycle(core_cycle), _core_time(core_time) { } -void Scheduler::schedule_graph(std::unique_ptr tile_graph) { - spdlog::info("[Scheduler {}] Tile Graph {} Scheduled", _id, "FIFO"); // TODO: tile graph id - // _tile_graph = TileGraphScheduler->get_tile_graph(); - _tile_graph.push_back(std::move(tile_graph)); +void Scheduler::enqueue_graph(std::unique_ptr tile_graph) { + std::shared_ptr sp = std::shared_ptr(std::move(tile_graph)); + sp->wire_subgraph_owner_links(); + _in_flight_graphs.insert(sp); + _tile_graph.push_back(std::move(sp)); refresh_status(); } +void Scheduler::finish_tile(std::shared_ptr tile) { + if (!tile || !tile->get_owner()) { + return; + } + auto sub = tile->get_owner(); + std::shared_ptr graph = sub->lock_owner_tile_graph(); + sub->finish_tile(tile); + if (!graph) { + return; + } + graph->try_emit_kernel_complete(*_core_cycle, _id); + if (graph->kernel_complete_logged()) { + _in_flight_graphs.erase(graph); + } +} + const std::shared_ptr Scheduler::peek_tile(int core_id, int slot_id, CoreType ctype) { + refresh_status(); if (_tile_graph.empty() || _tile_graph.at(0)->get_arrival_time() > *_core_cycle) return std::make_unique(Tile(Tile::Status::EMPTY)); if ((!_tile_graph.at(0)->StonneGraph && ctype == CoreType::WS_MESH) || (_tile_graph.at(0)->StonneGraph && ctype == CoreType::STONNE)) @@ -25,6 +43,10 @@ std::shared_ptr Scheduler::get_tile(int core_id, int slot_id) { return tile; } else { tile = std::move(_tile_graph.at(0)->get_tile(core_id, slot_id)); + // Record start_time when first non-EMPTY tile is issued + if (tile->get_status() != Tile::Status::EMPTY && _tile_graph.at(0)->get_start_time() == 0) { + _tile_graph.at(0)->set_start_time(*_core_cycle); + } } refresh_status(); return tile; @@ -46,13 +68,7 @@ void Scheduler::refresh_status() { if (_tile_graph.empty()) return; - /* Remove finished request */ if (_tile_graph.at(0)->is_finished()) { - spdlog::info("[Scheduler {}] Graph path: {} operation: {} finish at {}", - _id, _tile_graph.at(0)->get_graph_path(), - _tile_graph.at(0)->get_name(), *_core_cycle); - spdlog::info("Total compute time {}", - *_core_cycle - _tile_graph.at(0)->get_arrival_time()); _tile_graph.pop_front(); } } \ No newline at end of file diff --git a/configs/heterogeneous_c2_simple_noc.json b/configs/heterogeneous_c2_simple_noc.json deleted file mode 100644 index a68f38c2..00000000 --- a/configs/heterogeneous_c2_simple_noc.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "core_type" : ["stonne", "ws_mesh"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - - "num_stonne_per_core" : 8, - "num_stonne_port" : 64, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "num_partition" : 2, - "partition": { - "core_0":0, - "core_1":1 - }, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/heterogeneous_c2_simple_noc.yml b/configs/heterogeneous_c2_simple_noc.yml new file mode 100644 index 00000000..9c596d85 --- /dev/null +++ b/configs/heterogeneous_c2_simple_noc.yml @@ -0,0 +1,37 @@ +core_type: +- stonne +- ws_mesh +stonne_config_path: /workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 8 +num_stonne_port: 64 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +num_partition: 2 +partition: + core_0: 0 + core_1: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/ramulator2_configs/DDR4.yaml b/configs/ramulator2_configs/DDR4.yaml index e65528ed..45799436 100644 --- a/configs/ramulator2_configs/DDR4.yaml +++ b/configs/ramulator2_configs/DDR4.yaml @@ -1,25 +1,421 @@ -Frontend: - impl: GEM5 - -MemorySystem: - impl: GenericDRAM - clock_ratio: 1 - - DRAM: - impl: DDR4 - org: - preset: DDR4_16Gb_x4 - channel: 1 - timing: - preset: DDR4_1600J - - Controller: - impl: Generic - Scheduler: - impl: FRFCFS - RefreshManager: - impl: AllBank - plugins: - - AddrMapper: - impl: RoBaRaCoCh \ No newline at end of file +{ + "frontend": { + "impl": "External", + "clock_ratio": 1 + }, + "memory_system": { + "impl": "GenericDRAM", + "clock_ratio": 1, + "channel_mapper": { + "impl": "PassThroughChannelMapper" + }, + "controllers": [ + { + "impl": "GenericDDR", + "wr_low_watermark": 0.2, + "wr_high_watermark": 0.8, + "read_buffer_size": 32, + "write_buffer_size": 32, + "priority_buffer_size": 1568, + "scheduler": { + "impl": "FRFCFS" + }, + "refresh_manager": { + "impl": "AllBank", + "scope": "Rank" + }, + "row_policy": { + "impl": "Open" + }, + "addr_mapper": { + "impl": "RoBaRaCoCh" + }, + "dram": { + "impl": "DDR4", + "org": { + "dq": 8, + "count": [ + 1, + 1, + 4, + 4, + 65536, + 1024 + ] + }, + "timing": [ + 3200, + 4, + 22, + 22, + 22, + 52, + 74, + 24, + 12, + 16, + 4, + 8, + 4, + 8, + 4, + 12, + 34, + 576, + 12480, + 2, + 625 + ], + "channel_width": 64, + "read_latency": 26, + "timing_constraints": [ + [ + 0, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 4 + ], + [ + 0, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 4 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 4 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 4 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 4, + 6 + ], + 12 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 3, + 5 + ], + 24 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 3, + 4, + 5, + 6 + ], + 6, + 1, + true + ], + [ + 1, + [ + 4, + 6 + ], + [ + 3, + 5 + ], + 0, + 1, + true + ], + [ + 1, + [ + 3 + ], + [ + 2 + ], + 12 + ], + [ + 1, + [ + 4 + ], + [ + 2 + ], + 44 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 4 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 34, + 4 + ], + [ + 1, + [ + 0 + ], + [ + 2 + ], + 52 + ], + [ + 1, + [ + 2 + ], + [ + 0 + ], + 22 + ], + [ + 1, + [ + 0 + ], + [ + 7 + ], + 74 + ], + [ + 1, + [ + 1, + 2 + ], + [ + 7 + ], + 22 + ], + [ + 1, + [ + 5 + ], + [ + 7 + ], + 34 + ], + [ + 1, + [ + 6 + ], + [ + 7 + ], + 66 + ], + [ + 1, + [ + 7 + ], + [ + 0, + 2 + ], + 576 + ], + [ + 2, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 8 + ], + [ + 2, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 8 + ], + [ + 2, + [ + 4, + 6 + ], + [ + 3, + 5 + ], + 32 + ], + [ + 2, + [ + 0 + ], + [ + 0 + ], + 8 + ], + [ + 3, + [ + 0 + ], + [ + 0 + ], + 74 + ], + [ + 3, + [ + 0 + ], + [ + 3, + 4, + 5, + 6 + ], + 22 + ], + [ + 3, + [ + 0 + ], + [ + 1 + ], + 52 + ], + [ + 3, + [ + 1 + ], + [ + 0 + ], + 22 + ], + [ + 3, + [ + 3 + ], + [ + 1 + ], + 12 + ], + [ + 3, + [ + 4 + ], + [ + 1 + ], + 44 + ], + [ + 3, + [ + 5 + ], + [ + 0 + ], + 34 + ], + [ + 3, + [ + 6 + ], + [ + 0 + ], + 66 + ] + ] + } + } + ] + } +} \ No newline at end of file diff --git a/configs/ramulator2_configs/HBM2.yaml b/configs/ramulator2_configs/HBM2.yaml index 70cddef0..2bdd1705 100644 --- a/configs/ramulator2_configs/HBM2.yaml +++ b/configs/ramulator2_configs/HBM2.yaml @@ -1,25 +1,476 @@ -Frontend: - impl: GEM5 - -MemorySystem: - impl: GenericDRAM - clock_ratio: 1 - - DRAM: - impl: HBM2 - org: - preset: HBM2_8Gb - channel: 1 - timing: - preset: HBM2_1.4Gbps - - Controller: - impl: Generic - Scheduler: - impl: FRFCFS - RefreshManager: - impl: AllBank - plugins: - - AddrMapper: - impl: RoBaRaCoCh \ No newline at end of file +{ + "frontend": { + "impl": "External", + "clock_ratio": 1 + }, + "memory_system": { + "impl": "GenericDRAM", + "clock_ratio": 1, + "channel_mapper": { + "impl": "PassThroughChannelMapper" + }, + "controllers": [ + { + "impl": "GenericDDR", + "wr_low_watermark": 0.2, + "wr_high_watermark": 0.8, + "read_buffer_size": 32, + "write_buffer_size": 32, + "priority_buffer_size": 1568, + "scheduler": { + "impl": "FRFCFS" + }, + "refresh_manager": { + "impl": "AllBank", + "scope": "PseudoChannel" + }, + "row_policy": { + "impl": "Open" + }, + "addr_mapper": { + "impl": "RoBaRaCoCh" + }, + "dram": { + "impl": "HBM2", + "org": { + "dq": 64, + "count": [ + 1, + 2, + 4, + 4, + 65536, + 32 + ] + }, + "timing": [ + 2000, + 2, + 14, + 14, + 12, + 14, + 34, + 48, + 16, + 5, + 5, + 2, + 4, + 4, + 4, + 6, + 8, + 15, + 350, + 160, + 8, + 3900, + 122, + 1000 + ], + "channel_width": 64, + "read_latency": 16, + "timing_constraints": [ + [ + 0, + [ + 0 + ], + [ + 0, + 1, + 2, + 7, + 8 + ], + 2 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 2 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 2 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 2 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 2 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 4, + 6 + ], + 13 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 3, + 5 + ], + 13 + ], + [ + 1, + [ + 3 + ], + [ + 2 + ], + 5 + ], + [ + 1, + [ + 4 + ], + [ + 2 + ], + 23 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 4 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 15, + 4 + ], + [ + 1, + [ + 0 + ], + [ + 2 + ], + 35 + ], + [ + 1, + [ + 2 + ], + [ + 0 + ], + 13 + ], + [ + 1, + [ + 0 + ], + [ + 7 + ], + 49 + ], + [ + 1, + [ + 1, + 2 + ], + [ + 7 + ], + 14 + ], + [ + 1, + [ + 5 + ], + [ + 7 + ], + 19 + ], + [ + 1, + [ + 6 + ], + [ + 7 + ], + 37 + ], + [ + 1, + [ + 7 + ], + [ + 0 + ], + 349 + ], + [ + 1, + [ + 7 + ], + [ + 2 + ], + 350 + ], + [ + 1, + [ + 8 + ], + [ + 0 + ], + 7 + ], + [ + 1, + [ + 0 + ], + [ + 8 + ], + 5 + ], + [ + 2, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 4 + ], + [ + 2, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 4 + ], + [ + 2, + [ + 4, + 6 + ], + [ + 3, + 5 + ], + 15 + ], + [ + 2, + [ + 0 + ], + [ + 0 + ], + 4 + ], + [ + 3, + [ + 0 + ], + [ + 0 + ], + 48 + ], + [ + 3, + [ + 0 + ], + [ + 3, + 5 + ], + 15 + ], + [ + 3, + [ + 0 + ], + [ + 4, + 6 + ], + 13 + ], + [ + 3, + [ + 0 + ], + [ + 1 + ], + 35 + ], + [ + 3, + [ + 1 + ], + [ + 0 + ], + 13 + ], + [ + 3, + [ + 3 + ], + [ + 1 + ], + 5 + ], + [ + 3, + [ + 4 + ], + [ + 1 + ], + 23 + ], + [ + 3, + [ + 5 + ], + [ + 0 + ], + 18 + ], + [ + 3, + [ + 6 + ], + [ + 0 + ], + 36 + ], + [ + 3, + [ + 8 + ], + [ + 0 + ], + 159 + ], + [ + 3, + [ + 0 + ], + [ + 8 + ], + 49 + ], + [ + 3, + [ + 1 + ], + [ + 8 + ], + 14 + ] + ] + } + } + ] + } +} \ No newline at end of file diff --git a/configs/ramulator2_configs/HBM2_TPUv3.yaml b/configs/ramulator2_configs/HBM2_TPUv3.yaml index e6543d14..01cab613 100644 --- a/configs/ramulator2_configs/HBM2_TPUv3.yaml +++ b/configs/ramulator2_configs/HBM2_TPUv3.yaml @@ -1,25 +1,476 @@ -Frontend: - impl: GEM5 - -MemorySystem: - impl: GenericDRAM - clock_ratio: 1 - - DRAM: - impl: HBM2 - org: - preset: HBM2_8Gb - channel: 1 - timing: - preset: HBM2_1.8Gbps - - Controller: - impl: Generic - Scheduler: - impl: FRFCFS - RefreshManager: - impl: AllBank - plugins: - - AddrMapper: - impl: RoBaRaCoCh \ No newline at end of file +{ + "frontend": { + "impl": "External", + "clock_ratio": 1 + }, + "memory_system": { + "impl": "GenericDRAM", + "clock_ratio": 1, + "channel_mapper": { + "impl": "PassThroughChannelMapper" + }, + "controllers": [ + { + "impl": "HBM", + "wr_low_watermark": 0.2, + "wr_high_watermark": 0.8, + "read_buffer_size": 64, + "write_buffer_size": 64, + "priority_buffer_size": 1568, + "scheduler": { + "impl": "FRFCFS" + }, + "refresh_manager": { + "impl": "AllBank", + "scope": "PseudoChannel" + }, + "row_policy": { + "impl": "Open" + }, + "addr_mapper": { + "impl": "RoBaRaCoCh" + }, + "dram": { + "impl": "HBM2", + "org": { + "dq": 64, + "count": [ + 1, + 2, + 4, + 4, + 65536, + 32 + ] + }, + "timing": [ + 2000, + 2, + 14, + 14, + 12, + 14, + 34, + 48, + 16, + 5, + 5, + 2, + 4, + 4, + 4, + 6, + 8, + 15, + 350, + 160, + 8, + 3900, + 122, + 1000 + ], + "channel_width": 64, + "read_latency": 16, + "timing_constraints": [ + [ + 0, + [ + 0 + ], + [ + 0, + 1, + 2, + 7, + 8 + ], + 2 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 2 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 2 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 2 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 2 + ], + [ + 1, + [ + 3, + 5 + ], + [ + 4, + 6 + ], + 13 + ], + [ + 1, + [ + 4, + 6 + ], + [ + 3, + 5 + ], + 13 + ], + [ + 1, + [ + 3 + ], + [ + 2 + ], + 5 + ], + [ + 1, + [ + 4 + ], + [ + 2 + ], + 23 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 4 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 15, + 4 + ], + [ + 1, + [ + 0 + ], + [ + 2 + ], + 35 + ], + [ + 1, + [ + 2 + ], + [ + 0 + ], + 13 + ], + [ + 1, + [ + 0 + ], + [ + 7 + ], + 49 + ], + [ + 1, + [ + 1, + 2 + ], + [ + 7 + ], + 14 + ], + [ + 1, + [ + 5 + ], + [ + 7 + ], + 19 + ], + [ + 1, + [ + 6 + ], + [ + 7 + ], + 37 + ], + [ + 1, + [ + 7 + ], + [ + 0 + ], + 349 + ], + [ + 1, + [ + 7 + ], + [ + 2 + ], + 350 + ], + [ + 1, + [ + 8 + ], + [ + 0 + ], + 7 + ], + [ + 1, + [ + 0 + ], + [ + 8 + ], + 5 + ], + [ + 2, + [ + 3, + 5 + ], + [ + 3, + 5 + ], + 4 + ], + [ + 2, + [ + 4, + 6 + ], + [ + 4, + 6 + ], + 4 + ], + [ + 2, + [ + 4, + 6 + ], + [ + 3, + 5 + ], + 15 + ], + [ + 2, + [ + 0 + ], + [ + 0 + ], + 4 + ], + [ + 3, + [ + 0 + ], + [ + 0 + ], + 48 + ], + [ + 3, + [ + 0 + ], + [ + 3, + 5 + ], + 15 + ], + [ + 3, + [ + 0 + ], + [ + 4, + 6 + ], + 13 + ], + [ + 3, + [ + 0 + ], + [ + 1 + ], + 35 + ], + [ + 3, + [ + 1 + ], + [ + 0 + ], + 13 + ], + [ + 3, + [ + 3 + ], + [ + 1 + ], + 5 + ], + [ + 3, + [ + 4 + ], + [ + 1 + ], + 23 + ], + [ + 3, + [ + 5 + ], + [ + 0 + ], + 18 + ], + [ + 3, + [ + 6 + ], + [ + 0 + ], + 36 + ], + [ + 3, + [ + 8 + ], + [ + 0 + ], + 159 + ], + [ + 3, + [ + 0 + ], + [ + 8 + ], + 49 + ], + [ + 3, + [ + 1 + ], + [ + 8 + ], + 14 + ] + ] + } + } + ] + } +} \ No newline at end of file diff --git a/configs/ramulator2_configs/LPDDR5.yaml b/configs/ramulator2_configs/LPDDR5.yaml new file mode 100644 index 00000000..bf039f9f --- /dev/null +++ b/configs/ramulator2_configs/LPDDR5.yaml @@ -0,0 +1,494 @@ +{ + "frontend": { + "impl": "External", + "clock_ratio": 1 + }, + "memory_system": { + "impl": "GenericDRAM", + "clock_ratio": 1, + "channel_mapper": { + "impl": "PassThroughChannelMapper" + }, + "controllers": [ + { + "impl": "GenericDDR", + "wr_low_watermark": 0.2, + "wr_high_watermark": 0.8, + "read_buffer_size": 32, + "write_buffer_size": 32, + "priority_buffer_size": 1568, + "scheduler": { + "impl": "FRFCFS" + }, + "refresh_manager": { + "impl": "AllBank", + "scope": "Rank" + }, + "row_policy": { + "impl": "Open" + }, + "addr_mapper": { + "impl": "RoBaRaCoCh" + }, + "dram": { + "impl": "LPDDR5", + "org": { + "dq": 16, + "count": [ + 1, + 1, + 4, + 4, + 32768, + 1024 + ] + }, + "timing": [ + 6400, + 2, + 17, + 15, + 15, + 17, + 34, + 49, + 28, + 8, + 9, + 2, + 2, + 4, + 2, + 4, + 4, + 4, + 5, + 10, + 16, + 168, + 96, + 3125, + 391, + 1, + 0, + 8, + 2, + 1250 + ], + "channel_width": 16, + "read_latency": 19, + "timing_constraints": [ + [ + 0, + [ + 6, + 8 + ], + [ + 6, + 8 + ], + 2 + ], + [ + 0, + [ + 7, + 9 + ], + [ + 7, + 9 + ], + 2 + ], + [ + 3, + [ + 4 + ], + [ + 6, + 8 + ], + 0 + ], + [ + 3, + [ + 5 + ], + [ + 7, + 9 + ], + 0 + ], + [ + 1, + [ + 6, + 8 + ], + [ + 6, + 8 + ], + 2 + ], + [ + 1, + [ + 7, + 9 + ], + [ + 7, + 9 + ], + 2 + ], + [ + 1, + [ + 6, + 8 + ], + [ + 7, + 9 + ], + 12 + ], + [ + 1, + [ + 7, + 9 + ], + [ + 6, + 8 + ], + 16 + ], + [ + 1, + [ + 6, + 8 + ], + [ + 6, + 7, + 8, + 9 + ], + 4, + 1, + true + ], + [ + 1, + [ + 7, + 9 + ], + [ + 6, + 8 + ], + 12, + 1, + true + ], + [ + 1, + [ + 6 + ], + [ + 3 + ], + 8 + ], + [ + 1, + [ + 7 + ], + [ + 3 + ], + 39 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 4 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 16, + 4 + ], + [ + 1, + [ + 0 + ], + [ + 3 + ], + 34 + ], + [ + 1, + [ + 3 + ], + [ + 0 + ], + 17 + ], + [ + 1, + [ + 2, + 3 + ], + [ + 2, + 3 + ], + 2 + ], + [ + 1, + [ + 0 + ], + [ + 10 + ], + 49 + ], + [ + 1, + [ + 2, + 3 + ], + [ + 10 + ], + 15 + ], + [ + 1, + [ + 8 + ], + [ + 10 + ], + 23 + ], + [ + 1, + [ + 9 + ], + [ + 10 + ], + 54 + ], + [ + 1, + [ + 10 + ], + [ + 0, + 3 + ], + 168 + ], + [ + 2, + [ + 6, + 8 + ], + [ + 6, + 8 + ], + 4 + ], + [ + 2, + [ + 7, + 9 + ], + [ + 7, + 9 + ], + 4 + ], + [ + 2, + [ + 7, + 9 + ], + [ + 6, + 8 + ], + 21 + ], + [ + 2, + [ + 0 + ], + [ + 0 + ], + 4 + ], + [ + 3, + [ + 0 + ], + [ + 0 + ], + 49 + ], + [ + 3, + [ + 0 + ], + [ + 6, + 7, + 8, + 9 + ], + 15 + ], + [ + 3, + [ + 0 + ], + [ + 2 + ], + 34 + ], + [ + 3, + [ + 2 + ], + [ + 0 + ], + 15 + ], + [ + 3, + [ + 6 + ], + [ + 2 + ], + 8 + ], + [ + 3, + [ + 7 + ], + [ + 2 + ], + 39 + ], + [ + 3, + [ + 8 + ], + [ + 0 + ], + 23 + ], + [ + 3, + [ + 9 + ], + [ + 0 + ], + 54 + ], + [ + 3, + [ + 11 + ], + [ + 0 + ], + 96 + ], + [ + 3, + [ + 0 + ], + [ + 11 + ], + 49 + ], + [ + 3, + [ + 2 + ], + [ + 11 + ], + 15 + ] + ] + } + } + ] + } +} \ No newline at end of file diff --git a/configs/ramulator2_configs/LPDDR5X.yaml b/configs/ramulator2_configs/LPDDR5X.yaml new file mode 100644 index 00000000..4309aa6c --- /dev/null +++ b/configs/ramulator2_configs/LPDDR5X.yaml @@ -0,0 +1,494 @@ +{ + "frontend": { + "impl": "External", + "clock_ratio": 1 + }, + "memory_system": { + "impl": "GenericDRAM", + "clock_ratio": 1, + "channel_mapper": { + "impl": "PassThroughChannelMapper" + }, + "controllers": [ + { + "impl": "GenericDDR", + "wr_low_watermark": 0.2, + "wr_high_watermark": 0.8, + "read_buffer_size": 32, + "write_buffer_size": 32, + "priority_buffer_size": 1568, + "scheduler": { + "impl": "FRFCFS" + }, + "refresh_manager": { + "impl": "AllBank", + "scope": "Rank" + }, + "row_policy": { + "impl": "Open" + }, + "addr_mapper": { + "impl": "RoBaRaCoCh" + }, + "dram": { + "impl": "LPDDR5", + "org": { + "dq": 16, + "count": [ + 1, + 1, + 4, + 4, + 32768, + 1024 + ] + }, + "timing": [ + 8533, + 2, + 23, + 20, + 20, + 23, + 46, + 65, + 38, + 11, + 12, + 2, + 2, + 4, + 2, + 4, + 6, + 6, + 7, + 14, + 22, + 224, + 128, + 4165, + 521, + 1, + 0, + 8, + 2, + 938 + ], + "channel_width": 16, + "read_latency": 25, + "timing_constraints": [ + [ + 0, + [ + 6, + 8 + ], + [ + 6, + 8 + ], + 2 + ], + [ + 0, + [ + 7, + 9 + ], + [ + 7, + 9 + ], + 2 + ], + [ + 3, + [ + 4 + ], + [ + 6, + 8 + ], + 0 + ], + [ + 3, + [ + 5 + ], + [ + 7, + 9 + ], + 0 + ], + [ + 1, + [ + 6, + 8 + ], + [ + 6, + 8 + ], + 2 + ], + [ + 1, + [ + 7, + 9 + ], + [ + 7, + 9 + ], + 2 + ], + [ + 1, + [ + 6, + 8 + ], + [ + 7, + 9 + ], + 15 + ], + [ + 1, + [ + 7, + 9 + ], + [ + 6, + 8 + ], + 21 + ], + [ + 1, + [ + 6, + 8 + ], + [ + 6, + 7, + 8, + 9 + ], + 4, + 1, + true + ], + [ + 1, + [ + 7, + 9 + ], + [ + 6, + 8 + ], + 15, + 1, + true + ], + [ + 1, + [ + 6 + ], + [ + 3 + ], + 11 + ], + [ + 1, + [ + 7 + ], + [ + 3 + ], + 52 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 6 + ], + [ + 1, + [ + 0 + ], + [ + 0 + ], + 22, + 4 + ], + [ + 1, + [ + 0 + ], + [ + 3 + ], + 46 + ], + [ + 1, + [ + 3 + ], + [ + 0 + ], + 23 + ], + [ + 1, + [ + 2, + 3 + ], + [ + 2, + 3 + ], + 2 + ], + [ + 1, + [ + 0 + ], + [ + 10 + ], + 65 + ], + [ + 1, + [ + 2, + 3 + ], + [ + 10 + ], + 20 + ], + [ + 1, + [ + 8 + ], + [ + 10 + ], + 31 + ], + [ + 1, + [ + 9 + ], + [ + 10 + ], + 72 + ], + [ + 1, + [ + 10 + ], + [ + 0, + 3 + ], + 224 + ], + [ + 2, + [ + 6, + 8 + ], + [ + 6, + 8 + ], + 4 + ], + [ + 2, + [ + 7, + 9 + ], + [ + 7, + 9 + ], + 4 + ], + [ + 2, + [ + 7, + 9 + ], + [ + 6, + 8 + ], + 28 + ], + [ + 2, + [ + 0 + ], + [ + 0 + ], + 6 + ], + [ + 3, + [ + 0 + ], + [ + 0 + ], + 65 + ], + [ + 3, + [ + 0 + ], + [ + 6, + 7, + 8, + 9 + ], + 20 + ], + [ + 3, + [ + 0 + ], + [ + 2 + ], + 46 + ], + [ + 3, + [ + 2 + ], + [ + 0 + ], + 20 + ], + [ + 3, + [ + 6 + ], + [ + 2 + ], + 11 + ], + [ + 3, + [ + 7 + ], + [ + 2 + ], + 52 + ], + [ + 3, + [ + 8 + ], + [ + 0 + ], + 31 + ], + [ + 3, + [ + 9 + ], + [ + 0 + ], + 72 + ], + [ + 3, + [ + 11 + ], + [ + 0 + ], + 128 + ], + [ + 3, + [ + 0 + ], + [ + 11 + ], + 65 + ], + [ + 3, + [ + 2 + ], + [ + 11 + ], + 20 + ] + ] + } + } + ] + } +} \ No newline at end of file diff --git a/configs/ramulator2_configs/gen_configs.py b/configs/ramulator2_configs/gen_configs.py new file mode 100644 index 00000000..64eb62d2 --- /dev/null +++ b/configs/ramulator2_configs/gen_configs.py @@ -0,0 +1,109 @@ +""" +Generate machine-readable ramulator2 v2.1 config files for PyTorchSim. + +Usage: + python gen_configs.py + +Each function generates a JSON config that C++ can load directly via +Config::parse_config_file(). No preset resolution happens in C++ anymore. +""" + +import json +import sys +import os + +# Add ramulator2 Python DSL to path +RAMULATOR_PYTHON = os.path.join(os.path.dirname(__file__), + "../../TOGSim/extern/ramulator2/python") +sys.path.insert(0, RAMULATOR_PYTHON) + +import ramulator +import ramulator.dram +import ramulator.controller +import ramulator.scheduler +import ramulator.refresh_manager +import ramulator.row_policy +import ramulator.addr_mapper +import ramulator.channel_mapper +import ramulator.memory_system + + +def make_config(dram_obj, clock_ratio=1, refresh_scope="Rank"): + """Wrap a DRAM object in a single-channel GenericDRAM config for PyTorchSim. + + PyTorchSim creates one Ramulator2 instance per channel, so each config + always has exactly one controller (channel=1 in org is enforced by v2.1). + The wrapper overrides 'frontend' to ExternalFrontEnd automatically. + + refresh_scope: level name for AllBank refresh. + - DDR4 / LPDDR5 / LPDDR5X → "Rank" + - HBM2 / HBM3 → "PseudoChannel" + """ + ctrl = ramulator.controller.GenericDDR( + dram=dram_obj, + scheduler=ramulator.scheduler.FRFCFS(), + refresh_manager=ramulator.refresh_manager.AllBank(scope=refresh_scope), + row_policy=ramulator.row_policy.Open(), + addr_mapper=ramulator.addr_mapper.RoBaRaCoCh(), + ) + ms = ramulator.memory_system.GenericDRAM( + clock_ratio=clock_ratio, + controllers=[ctrl], + # Single-channel per Ramulator2 instance — passthrough maps everything to ch 0 + channel_mapper=ramulator.channel_mapper.PassThroughChannelMapper(), + ) + return { + "frontend": {"impl": "External", "clock_ratio": 1}, + "memory_system": ms.to_config(), + } + + +def gen_hbm2(): + # Available timing presets: HBM2_1600Mbps, HBM2_2000Mbps, HBM2_2400Mbps + # HBM2 has no Rank level — AllBank refresh scope must be PseudoChannel + dram = ramulator.dram.HBM2(org_preset="HBM2_8Gb", timing_preset="HBM2_2000Mbps") + return make_config(dram, clock_ratio=1, refresh_scope="PseudoChannel") + + +def gen_hbm2_tpuv3(): + # TPUv3 HBM2: 900MHz → ~1.8 Gbps. Closest available preset: HBM2_2000Mbps + dram = ramulator.dram.HBM2(org_preset="HBM2_8Gb", timing_preset="HBM2_2000Mbps") + return make_config(dram, clock_ratio=1, refresh_scope="PseudoChannel") + + +def gen_ddr4(): + # Available timing presets — check python/ramulator/dram/ddr4.py + dram = ramulator.dram.DDR4(org_preset="DDR4_8Gb_x8", timing_preset="DDR4_3200AA") + return make_config(dram, clock_ratio=1) + + +def gen_lpddr5(): + dram = ramulator.dram.LPDDR5(org_preset="LPDDR5_8Gb_x16", timing_preset="LPDDR5_6400") + return make_config(dram, clock_ratio=1) + + +def gen_lpddr5x(): + # LPDDR5X_8533: 8533 MT/s, tCK=938ps, CK=1066MHz + dram = ramulator.dram.LPDDR5(org_preset="LPDDR5_8Gb_x16", timing_preset="LPDDR5X_8533") + return make_config(dram, clock_ratio=1) + + +CONFIGS = { + "HBM2.yaml": gen_hbm2, + "HBM2_TPUv3.yaml": gen_hbm2_tpuv3, + "DDR4.yaml": gen_ddr4, + "LPDDR5.yaml": gen_lpddr5, + "LPDDR5X.yaml": gen_lpddr5x, +} + + +if __name__ == "__main__": + out_dir = os.path.dirname(os.path.abspath(__file__)) + for filename, gen_fn in CONFIGS.items(): + cfg = gen_fn() + out_path = os.path.join(out_dir, filename) + with open(out_path, "w") as f: + # json is valid yaml — C++ parse_config_file reads either + json.dump(cfg, f, indent=2) + print(f"Generated {out_path}") + diff --git a/configs/stonne_big_c1_simple_noc.json b/configs/stonne_big_c1_simple_noc.json deleted file mode 100644 index 0a8ca3c2..00000000 --- a/configs/stonne_big_c1_simple_noc.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "core_type" : ["stonne"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_stonne_per_core" : 8, - "num_stonne_port" : 64, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 8, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycless": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16 -} \ No newline at end of file diff --git a/configs/stonne_big_c1_simple_noc.yml b/configs/stonne_big_c1_simple_noc.yml new file mode 100644 index 00000000..b14838c8 --- /dev/null +++ b/configs/stonne_big_c1_simple_noc.yml @@ -0,0 +1,21 @@ +core_type: +- stonne +stonne_config_path: /workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 8 +num_stonne_port: 64 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 8 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycless: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 diff --git a/configs/stonne_single_c1_simple_noc.json b/configs/stonne_single_c1_simple_noc.json deleted file mode 100644 index 3421d4f1..00000000 --- a/configs/stonne_single_c1_simple_noc.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "core_type" : ["stonne"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 1, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - "num_stonne_per_core" : 1, - "num_stonne_port" : 8, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 700, - "dram_channels": 8, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 8 -} \ No newline at end of file diff --git a/configs/stonne_single_c1_simple_noc.yml b/configs/stonne_single_c1_simple_noc.yml new file mode 100644 index 00000000..0ed7962c --- /dev/null +++ b/configs/stonne_single_c1_simple_noc.yml @@ -0,0 +1,21 @@ +core_type: +- stonne +stonne_config_path: /workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 1 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 1 +num_stonne_port: 8 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 8 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 8 diff --git a/configs/stonne_validation_c1_simple_noc.json b/configs/stonne_validation_c1_simple_noc.json deleted file mode 100644 index fb196dfb..00000000 --- a/configs/stonne_validation_c1_simple_noc.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "core_type" : ["stonne"], - "stonne_config_path" : "/workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", - "num_cores" : 1, - "core_freq_mhz" : 1000, - "core_stats_print_period_cycles" : 10000, - "num_stonne_per_core" : 1, - "num_stonne_port" : 32, - - "dram_type" : "simple", - "dram_freq_mhz" : 1000, - "dram_channels": 1, - "dram_req_size_byte": 32, - "dram_latency" : 100, - "dram_stats_print_period_cycles": 10000, - "l2d_type" : "datacache", - "l2d_config" : "S:128:128:64,32,L:T:m:W:L,A:192:4,32:0,32", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 1000, - "icnt_injection_ports_per_core" : 8 -} \ No newline at end of file diff --git a/configs/stonne_validation_c1_simple_noc.yml b/configs/stonne_validation_c1_simple_noc.yml new file mode 100644 index 00000000..f86dcce1 --- /dev/null +++ b/configs/stonne_validation_c1_simple_noc.yml @@ -0,0 +1,22 @@ +core_type: +- stonne +stonne_config_path: /workspace/PyTorchSim/TOGSim/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg +num_cores: 1 +core_freq_mhz: 1000 +core_stats_print_period_cycles: 10000 +num_stonne_per_core: 1 +num_stonne_port: 32 + +dram_type: simple +dram_freq_mhz: 1000 +dram_channels: 1 +dram_req_size_byte: 32 +dram_latency: 100 +dram_stats_print_period_cycles: 10000 +l2d_type: datacache +l2d_config: S:128:128:64,32,L:T:m:W:L,A:192:4,32:0,32 + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 1000 +icnt_injection_ports_per_core: 8 diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv2.json b/configs/systolic_ws_128x128_c1_booksim_tpuv2.json deleted file mode 100644 index 686827dc..00000000 --- a/configs/systolic_ws_128x128_c1_booksim_tpuv2.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :700, - "dram_channels": 16, - "dram_req_size_byte": 32, - - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/fly_c16_m16.icnt", - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv2.yml b/configs/systolic_ws_128x128_c1_booksim_tpuv2.yml new file mode 100644 index 00000000..08149005 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_booksim_tpuv2.yml @@ -0,0 +1,26 @@ +num_cores: 1 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/fly_c16_m16.icnt + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv3.json b/configs/systolic_ws_128x128_c1_booksim_tpuv3.json deleted file mode 100644 index 1109dc0f..00000000 --- a/configs/systolic_ws_128x128_c1_booksim_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/fly_c16_m16.icnt", - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} diff --git a/configs/systolic_ws_128x128_c1_booksim_tpuv3.yml b/configs/systolic_ws_128x128_c1_booksim_tpuv3.yml new file mode 100644 index 00000000..12304ce2 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_booksim_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/fly_c16_m16.icnt + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json deleted file mode 100644 index 22aedcf8..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 700, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycless": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.yml new file mode 100644 index 00000000..aec29ff8 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.yml @@ -0,0 +1,29 @@ +num_cores: 1 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycless: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json deleted file mode 100644 index e8e489d9..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml new file mode 100644 index 00000000..72873f1c --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json deleted file mode 100644 index 980bfc73..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 8, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml new file mode 100644 index 00000000..c2e962e3 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 8 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_timing_only.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_timing_only.yml new file mode 100644 index 00000000..f8ac0a54 --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_timing_only.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 0 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json b/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json deleted file mode 100644 index 02bfd75c..00000000 --- a/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 1050, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 4, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :1200, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - "l2d_type" : "datacache", - "l2d_config" : "S:128:128:512,32,L:T:m:W:L,A:192:4,32:0,32", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 1050, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.yml b/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.yml new file mode 100644 index 00000000..0415876d --- /dev/null +++ b/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.yml @@ -0,0 +1,32 @@ +num_cores: 1 +core_freq_mhz: 1050 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 4 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 1200 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml +l2d_type: datacache +l2d_config: S:128:128:512,32,L:T:m:W:L,A:192:4,32:0,32 + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 1050 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3.json b/configs/systolic_ws_128x128_c2_booksim_tpuv3.json deleted file mode 100644 index 66566324..00000000 --- a/configs/systolic_ws_128x128_c2_booksim_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/fly_c32_m32.icnt", - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml b/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml new file mode 100644 index 00000000..e411c0f3 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/fly_c32_m32.icnt + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.json b/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.json deleted file mode 100644 index 8ef47e87..00000000 --- a/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "sram_size" : 65536, - "core_print_interval" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq" : 940, - "dram_channels": 8, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_nbl" : 2, - "dram_print_interval": 10000, - "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_latency_cycles" : 10, - "icnt_freq" : 940, - "icnt_injection_ports_per_core" : 16, - "icnt_config_path" : "../configs/booksim2_configs/fly_c32_m8.icnt", - - "precision" : 4, - "scheduler" : "simple", - "num_partition" : 2, - "partition": { - "core_0":0, - "core_1":0 - }, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.yml b/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.yml new file mode 100644 index 00000000..f164b108 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_booksim_tpuv3_bw_quarter.yml @@ -0,0 +1,39 @@ +num_cores: 2 +core_freq_mhz: 940 +sram_size: 65536 +core_print_interval: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq: 940 +dram_channels: 8 +dram_req_size: 32 +dram_latency: 10 +dram_nbl: 2 +dram_print_interval: 10000 +dram_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_latency_cycles: 10 +icnt_freq: 940 +icnt_injection_ports_per_core: 16 +icnt_config_path: ../configs/booksim2_configs/fly_c32_m8.icnt +precision: 4 +scheduler: simple +num_partition: 2 +partition: + core_0: 0 + core_1: 0 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json b/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json deleted file mode 100644 index ecd671bf..00000000 --- a/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "dram_num_partitions" : 2, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 1000, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/chiplet_32_32_2.icnt", - "icnt_stats_print_period_cycles" : 10000, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml b/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml new file mode 100644 index 00000000..e38f091f --- /dev/null +++ b/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml @@ -0,0 +1,32 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +dram_num_partitions: 2 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 1000 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/chiplet_32_32_2.icnt +icnt_stats_print_period_cycles: 10000 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json b/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json deleted file mode 100644 index 168fbe3a..00000000 --- a/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "dram_num_partitions" : 1, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 1000, - "icnt_injection_ports_per_core" : 16, - "booksim_config_path" : "../configs/booksim2_configs/chiplet_32_32_2.icnt", - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml b/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml new file mode 100644 index 00000000..57696243 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml @@ -0,0 +1,31 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +dram_num_partitions: 1 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 1000 +icnt_injection_ports_per_core: 16 +booksim_config_path: ../configs/booksim2_configs/chiplet_32_32_2.icnt + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json deleted file mode 100644 index 0a5f15b2..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 700, - "core_stats_print_period_cycles" : 10000, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :700, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 700, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml new file mode 100644 index 00000000..f0686055 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml @@ -0,0 +1,29 @@ +num_cores: 2 +core_freq_mhz: 700 +core_stats_print_period_cycles: 10000 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 700 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 700 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json deleted file mode 100644 index f099b93d..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.yml new file mode 100644 index 00000000..511a5a09 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.yml @@ -0,0 +1,30 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_ils.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_ils.yml new file mode 100644 index 00000000..ce2d932d --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_ils.yml @@ -0,0 +1,33 @@ +# ILS (Instruction-Level Simulation) 전용 config +# - pytorchsim_functional_mode: 0 (timing only, no validation) +# - codegen_mapping_strategy: heuristic (no autotune) +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 0 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json deleted file mode 100644 index 681ef884..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "num_partition" : 2, - "partition": { - "core_0":0, - "core_1":1 - }, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml new file mode 100644 index 00000000..499ad823 --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml @@ -0,0 +1,34 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 +num_partition: 2 +partition: + core_0: 0 + core_1: 1 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json b/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json deleted file mode 100644 index d09228a1..00000000 --- a/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 1050, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 4, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :1200, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - "l2d_type" : "datacache", - "l2d_config" : "S:64:128:512,32,L:B:m:W:L,A:192:4,32:0,32", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 1050, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml b/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml new file mode 100644 index 00000000..da40f01e --- /dev/null +++ b/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml @@ -0,0 +1,32 @@ +num_cores: 2 +core_freq_mhz: 1050 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 4 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 1200 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2.yaml +l2d_type: datacache +l2d_config: S:64:128:512,32,L:B:m:W:L,A:192:4,32:0,32 + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 1050 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_8x8_c1_booksim.json b/configs/systolic_ws_8x8_c1_booksim.json deleted file mode 100644 index 851664e6..00000000 --- a/configs/systolic_ws_8x8_c1_booksim.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 800, - "core_stats_print_period_cycles" : 100000, - - "vpu_num_lanes" : 8, - "vpu_spad_size_kb_per_lane" : 32, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :800, - "dram_channels": 1, - "dram_req_size_byte": 64, - "dram_num_burst_length" : 4, - "dram_stats_print_period_cycles": 100000, - "ramulator_config_path" : "../configs/ramulator2_configs/DDR4.yaml", - - "icnt_type" : "booksim2", - "icnt_freq_mhz" : 800, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_8x8_c1_booksim.yml b/configs/systolic_ws_8x8_c1_booksim.yml new file mode 100644 index 00000000..6fd305f9 --- /dev/null +++ b/configs/systolic_ws_8x8_c1_booksim.yml @@ -0,0 +1,27 @@ +num_cores: 1 +core_freq_mhz: 800 +core_stats_print_period_cycles: 100000 + +vpu_num_lanes: 8 +vpu_spad_size_kb_per_lane: 32 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 800 +dram_channels: 1 +dram_req_size_byte: 64 +dram_num_burst_length: 4 +dram_stats_print_period_cycles: 100000 +ramulator_config_path: ../configs/ramulator2_configs/DDR4.yaml + +icnt_type: booksim2 +icnt_freq_mhz: 800 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/configs/systolic_ws_8x8_c1_simple_noc.json b/configs/systolic_ws_8x8_c1_simple_noc.json deleted file mode 100644 index 2eb7e183..00000000 --- a/configs/systolic_ws_8x8_c1_simple_noc.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 800, - "core_stats_print_period_cycles" : 100000, - - "vpu_num_lanes" : 8, - "vpu_spad_size_kb_per_lane" : 32, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" :800, - "dram_channels": 1, - "dram_req_size_byte": 64, - "dram_num_burst_length" : 4, - "dram_stats_print_period_cycles": 100000, - "ramulator_config_path" : "../configs/ramulator2_configs/DDR4.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 800, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/configs/systolic_ws_8x8_c1_simple_noc.yml b/configs/systolic_ws_8x8_c1_simple_noc.yml new file mode 100644 index 00000000..274f633c --- /dev/null +++ b/configs/systolic_ws_8x8_c1_simple_noc.yml @@ -0,0 +1,28 @@ +num_cores: 1 +core_freq_mhz: 800 +core_stats_print_period_cycles: 100000 + +vpu_num_lanes: 8 +vpu_spad_size_kb_per_lane: 32 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 800 +dram_channels: 1 +dram_req_size_byte: 64 +dram_num_burst_length: 4 +dram_stats_print_period_cycles: 100000 +ramulator_config_path: ../configs/ramulator2_configs/DDR4.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 800 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/experiments/BERT.py b/experiments/BERT.py index 3311682c..12e3cb33 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -1,57 +1,42 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_BERT(size, input_seq, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - # from tests.test_transformer import EncoderBlock - from tests.Fusion.test_transformer_fusion import EncoderBlock - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - hidden_dim = {'base': 768, 'large': 1024, 'xlarge': 2048} - embedding_size = {'base': 768, 'large': 1024, 'xlarge': 2048} - heads = {'base': 12, 'large': 16, 'xlarge': 32} # hidden/64 https://arxiv.org/pdf/1909.11942 - cpu_query = torch.randn(input_seq, hidden_dim[size]) - encoder_block = EncoderBlock(embedding_size[size], heads[size]).eval() - - query = cpu_query.clone().to(device=device) - opt_fn = torch.compile(dynamic=False)(encoder_block.to(device=device)) +import torch +from Simulator.simulator import TOGSimulator - SchedulerDNNModel.register_model(f"BERT-{size}", opt_fn) - request = Request(f"BERT-{size}", [query], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_timing_only.yml') +os.environ['TOGSIM_CONFIG'] = config - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() +# Try Fusion EncoderBlock first, fall back to standard test_transformer +try: + from tests.Fusion.test_transformer_fusion import EncoderBlock +except ImportError: + from tests.test_transformer import EncoderBlock - print(f"BERT-{size} Simulation Done") +HIDDEN_DIM = {'base': 768, 'large': 1024, 'xlarge': 2048} +EMBEDDING_SIZE = {'base': 768, 'large': 1024, 'xlarge': 2048} +HEADS = {'base': 12, 'large': 16, 'xlarge': 32} if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path FIXME: gem5 result is different as directoy name - sys.path.append(base_dir) args = argparse.ArgumentParser() - args.add_argument('--size', type=str, default='base') - args.add_argument('--dump_path', type=str, default='results') + args.add_argument('--size', type=str, default='base', choices=['base', 'large', 'xlarge']) args.add_argument('--input_size', type=int, default=512) args = args.parse_args() - size = args.size - input_seq = args.input_size - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"BERT_{size}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_BERT(size, input_seq, config) + + hidden_dim = HIDDEN_DIM[args.size] + embedding_size = EMBEDDING_SIZE[args.size] + heads = HEADS[args.size] + + device = torch.device("npu:0") + model = EncoderBlock(embedding_size, heads).eval().to(device=device) + model_input = torch.randn(args.input_size, hidden_dim).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"BERT-{args.size} Simulation Done") diff --git a/experiments/artifact/cycle_validation/run_cycle.sh b/experiments/artifact/cycle_validation/run_cycle.sh index 99eed4ed..e49538d0 100755 --- a/experiments/artifact/cycle_validation/run_cycle.sh +++ b/experiments/artifact/cycle_validation/run_cycle.sh @@ -1,85 +1,153 @@ #!/bin/bash set -e -export TORCHSIM_CONFIG=$TORCHSIM_DIR/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json +usage() { + cat <<'EOF' +Usage: run_cycle.sh [--only SECTION[,SECTION...]] + + Run cycle validation benchmarks. Default: all sections + summary. + + SECTION (comma-separated for --only): + matmul GEMM sizes + conv Conv2d sizes + layernorm LayerNorm sizes + softmax Softmax sizes + attention Attention sizes + resnet resnet18, resnet50 + bert BERT base/large/xlarge + summary summary_cycle.py (reads logs under experiments/artifact/logs) + +Examples: + ./run_cycle.sh + ./run_cycle.sh --only matmul + ./run_cycle.sh --only matmul,conv,summary +EOF +} + +ONLY="" +while [[ $# -gt 0 ]]; do + case "$1" in + --only) + ONLY="${2:-}" + if [[ -z "$ONLY" ]]; then echo "error: --only needs a value"; exit 1; fi + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "error: unknown argument: $1" >&2 + usage >&2 + exit 1 + ;; + esac +done + +# If ONLY is set, run section NAME only when ",$NAME," appears in ",$ONLY," +should_run() { + local name=$1 + if [[ -z "$ONLY" ]]; then + return 0 + fi + [[ ",${ONLY}," == *",${name},"* ]] +} + +export TOGSIM_CONFIG=$TORCHSIM_DIR/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_timing_only.yml LOG_DIR=$TORCHSIM_DIR/experiments/artifact/logs mkdir -p $LOG_DIR # Matmul -for sz in "256 256 256" "512 512 512" "1024 1024 1024" "2048 2048 2048"; do - name="gemm_${sz// /x}" - echo "" - echo "===================================================" - echo "[*] Running Matmul size=$sz" - echo "===================================================" - python3 $TORCHSIM_DIR/experiments/gemm.py --size $sz | tee $LOG_DIR/${name}.log -done +if should_run matmul; then + for sz in "256 256 256" "512 512 512" "1024 1024 1024" "2048 2048 2048"; do + name="gemm_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Matmul size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/gemm.py --size $sz | tee $LOG_DIR/${name}.log + done +fi # Conv -for sz in \ - "1 56 56 64 64 3 1 1" \ - "1 28 28 128 128 3 1 1" \ - "1 14 14 256 256 3 1 1" \ - "1 7 7 512 512 3 1 1" \ - "64 56 56 64 64 3 1 1" \ - "64 28 28 128 128 3 1 1" \ - "64 14 14 256 256 3 1 1" \ - "64 7 7 512 512 3 1 1"; do - name="conv_${sz// /x}" - echo "" - echo "===================================================" - echo "[*] Running Conv size=$sz" - echo "===================================================" - python3 $TORCHSIM_DIR/experiments/conv.py --size $sz | tee $LOG_DIR/${name}.log -done - -# Attention -for sz in "12 512 64" "16 512 64" "32 512 64"; do - name="attention_${sz// /x}" - echo "" - echo "===================================================" - echo "[*] Running Attention size=$sz" - echo "===================================================" - python3 $TORCHSIM_DIR/experiments/attention.py --size $sz | tee $LOG_DIR/${name}.log -done +if should_run conv; then + for sz in \ + "1 56 56 64 64 3 1 1" \ + "1 28 28 128 128 3 1 1" \ + "1 14 14 256 256 3 1 1" \ + "1 7 7 512 512 3 1 1" \ + "64 56 56 64 64 3 1 1" \ + "64 28 28 128 128 3 1 1" \ + "64 14 14 256 256 3 1 1" \ + "64 7 7 512 512 3 1 1"; do + name="conv_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Conv size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/conv.py --size $sz | tee $LOG_DIR/${name}.log + done +fi # LayerNorm -for sz in "512 768" "2048 768" "8192 768"; do - name="layernorm_${sz// /x}" - echo "" - echo "===================================================" - echo "[*] Running LayerNorm size=$sz" - echo "===================================================" - python3 $TORCHSIM_DIR/experiments/layernorm.py --size $sz | tee $LOG_DIR/${name}.log -done +if should_run layernorm; then + for sz in "512 768" "2048 768" "8192 768"; do + name="layernorm_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running LayerNorm size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/layernorm.py --size $sz | tee $LOG_DIR/${name}.log + done +fi # Softmax -for sz in "512 512" "2048 2048" "8192 8192"; do - name="softmax_${sz// /x}" - echo "" - echo "===================================================" - echo "[*] Running Softmax size=$sz" - echo "===================================================" - python3 $TORCHSIM_DIR/experiments/softmax.py --size $sz | tee $LOG_DIR/${name}.log -done +if should_run softmax; then + for sz in "512 512" "2048 2048" "8192 8192"; do + name="softmax_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Softmax size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/softmax.py --size $sz | tee $LOG_DIR/${name}.log + done +fi + +# Attention +if should_run attention; then + for sz in "12 512 64" "16 512 64" "32 512 64"; do + name="attention_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Attention size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/attention.py --size $sz | tee $LOG_DIR/${name}.log + done +fi # ResNet -for model in "resnet18" "resnet50"; do - echo "" - echo "===================================================" - echo "[*] Running $model" - echo "===================================================" - python3 $TORCHSIM_DIR/experiments/${model}.py | tee $LOG_DIR/${model}.log -done +if should_run resnet; then + for model in "resnet18" "resnet50"; do + echo "" + echo "===================================================" + echo "[*] Running $model" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/${model}.py | tee $LOG_DIR/${model}.log + done +fi # BERT -for model in "base" "large" "xlarge"; do - echo "" - echo "===================================================" - echo "[*] Running BERT size=$model" - echo "===================================================" - python3 $TORCHSIM_DIR/experiments/BERT.py --size $model | tee $LOG_DIR/bert_${model}.log -done +if should_run bert; then + for model in "base" "large" "xlarge"; do + echo "" + echo "===================================================" + echo "[*] Running BERT size=$model" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/BERT.py --size $model | tee $LOG_DIR/bert_${model}.log + done +fi # Cycle Summary -python3 $TORCHSIM_DIR/experiments/artifact/cycle_validation/summary_cycle.py | tee "$TORCHSIM_DIR/experiments/artifact/cycle_validation/summary_cycle.out" \ No newline at end of file +if should_run summary; then + python3 $TORCHSIM_DIR/experiments/artifact/cycle_validation/summary_cycle.py | tee "$TORCHSIM_DIR/experiments/artifact/cycle_validation/summary_cycle.out" +fi diff --git a/experiments/artifact/speedup/run_speedup.sh b/experiments/artifact/speedup/run_speedup.sh index 9a19e9af..cb5ee511 100755 --- a/experiments/artifact/speedup/run_speedup.sh +++ b/experiments/artifact/speedup/run_speedup.sh @@ -1,11 +1,15 @@ #!/bin/bash +set -e + LOG_DIR=$TORCHSIM_DIR/experiments/artifact/logs CONFIG_DIR="$TORCHSIM_DIR/configs" -SIMULATOR_BIN="$TORCHSIM_DIR/TOGSim/build/bin/Simulator" +EXTRACT_TRACE="$TORCHSIM_DIR/experiments/artifact/speedup/scripts/extract_trace_from_log.py" +TRACE_CACHE_DIR="$TORCHSIM_DIR/experiments/artifact/speedup/trace_cache" +mkdir -p "$TRACE_CACHE_DIR" configs=( - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - "systolic_ws_128x128_c2_booksim_tpuv3.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3.yml" + "systolic_ws_128x128_c2_booksim_tpuv3.yml" ) target_list=( @@ -25,9 +29,11 @@ output_dir="$TORCHSIM_DIR/experiments/artifact/speedup/results" mkdir -p "$output_dir" echo "[*] Scanning log files in: $LOG_DIR" +echo "[*] Extracting [TOGSim] Run command and trace from logs" echo "" for log_file in "$LOG_DIR"/*.log; do + [[ -f "$log_file" ]] || continue filename=$(basename "$log_file") workload="${filename%.log}" @@ -36,45 +42,38 @@ for log_file in "$LOG_DIR"/*.log; do fi echo "==> Workload: $workload" - declare -a ONNX_ATTR_PAIRS=() + # === Extract [TOGSim] Run command from log === + base_cmd=$(grep "\[TOGSim\] Run command:" "$log_file" 2>/dev/null | sed 's/.*\[TOGSim\] Run command: //' | head -1) + if [[ -z "$base_cmd" ]]; then + echo " Skipping: no [TOGSim] Run command found in $log_file" + continue + fi - # === Grep launch line === - while IFS= read -r line; do - if [[ "$line" == launch* ]]; then - read -r _ onnx_path attr_path _ <<< "$line" - ONNX_ATTR_PAIRS+=("$onnx_path|$attr_path") - fi - done < "$log_file" + # === Get trace file (replace FIFO in command; stored trace or generate from log) === + trace_file=$(python3 "$EXTRACT_TRACE" "$log_file" "$TRACE_CACHE_DIR/${workload}.trace" 2>/dev/null) || true + if [[ -z "$trace_file" || ! -f "$trace_file" ]]; then + echo " Skipping: could not extract trace from $log_file" + continue + fi # Normal configs for config in "${configs[@]}"; do - output_file="$output_dir/${workload}_${config}.txt" - echo "Running with config=$config" - echo "===== config=$config | model=$workload =====" >> "$output_file" + output_file="$output_dir/${workload}_${config}.txt" + echo "===== config=$config | model=$workload =====" > "$output_file" sum_all_iters=0.0 iter_count=0 - # === Run 5 iterations === for iter in {1..5}; do echo "[Iter $iter] Running simulation for workload=$workload config=$config" - cmd="" - for pair in "${ONNX_ATTR_PAIRS[@]}"; do - IFS="|" read -r onnx_path attr_path <<< "$pair" - cmd+=" $SIMULATOR_BIN --config $CONFIG_DIR/$config --models_list $onnx_path --attributes_list $attr_path;" - done - - output=$(bash -c "$cmd") - sim_times=$(echo "$output" | grep "Simulation time:" | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') - - if [[ -n "$sim_times" ]]; then - sum_per_iter=0.0 - while IFS= read -r sim_time; do - echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" - sum_per_iter=$(awk -v a="$sum_per_iter" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') - done <<< "$sim_times" - - echo "Iteration $iter: total_simulation_time = $sum_per_iter" >> "$output_file" - sum_all_iters=$(awk -v a="$sum_all_iters" -v b="$sum_per_iter" 'BEGIN {printf "%.6f", a + b}') + # Build command: replace --config and --models_list in base_cmd with our config and trace + cmd=$(echo "$base_cmd" | sed -E "s|--config [^ ]+|--config $CONFIG_DIR/$config|" | sed -E "s|--models_list [^ ]+|--models_list $trace_file|") + echo "$cmd" + output=$(bash -c "$cmd" 2>&1) || true + sim_time=$(echo "$output" | grep "Wall-clock time for simulation:" | sed -E 's/.*Wall-clock time for simulation: ([0-9]+\.[0-9]+) seconds.*/\1/') + + if [[ -n "$sim_time" ]]; then + echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" + sum_all_iters=$(awk -v a="$sum_all_iters" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') iter_count=$((iter_count + 1)) else echo "Iteration $iter: No simulation time found." >> "$output_file" diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh b/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh index fe872e02..642fec34 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh @@ -2,10 +2,7 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3_ils.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SIZE_LIST=( @@ -31,15 +28,11 @@ for i in "${config[@]}"; do for iter in {1..5}; do echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" output=$(bash -c " - export TORCHSIM_TLS_MODE=0; - export TORCHSIM_VALIDATION_MODE=0; - export TORCHSIM_CONFIG=$config_path; - export AUTOTUNE=0; - printenv; - python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + export TOGSIM_CONFIG=$config_path; + cd $TORCHSIM_DIR && python3 $workload 2>&1 ") - sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + sim_time=$(echo "$output" | grep "Wall-clock time for simulation:" | tail -n 1 | sed -E 's/.*Wall-clock time for simulation: ([0-9]+\.[0-9]+) seconds.*/\1/') if [[ -n "$sim_time" ]]; then echo "Iteration $iter: Simulation time = $sim_time" diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh b/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh index 19613a34..f5602668 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh @@ -2,10 +2,7 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3_ils.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SHAPE_LIST=( @@ -32,15 +29,11 @@ for i in "${config[@]}"; do for iter in {1..5}; do echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" output=$(bash -c " - export TORCHSIM_TLS_MODE=0; - export TORCHSIM_VALIDATION_MODE=0; - export TORCHSIM_CONFIG=$config_path; - export AUTOTUNE=0; - printenv; - python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + export TOGSIM_CONFIG=$config_path; + cd $TORCHSIM_DIR && python3 $workload 2>&1 ") - sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + sim_time=$(echo "$output" | grep "Wall-clock time for simulation:" | tail -n 1 | sed -E 's/.*Wall-clock time for simulation: ([0-9]+\.[0-9]+) seconds.*/\1/') if [[ -n "$sim_time" ]]; then echo "Iteration $iter: Simulation time = $sim_time" diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh b/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh index 6f3385f1..bc912aa6 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh @@ -2,10 +2,7 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3_ils.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SHAPE_LIST=( @@ -30,15 +27,11 @@ for i in "${config[@]}"; do for iter in {1..5}; do echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" output=$(bash -c " - export TORCHSIM_TLS_MODE=0; - export TORCHSIM_VALIDATION_MODE=1; - export TORCHSIM_CONFIG=$config_path; - export AUTOTUNE=0; - printenv; - python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + export TOGSIM_CONFIG=$config_path; + cd $TORCHSIM_DIR && python3 $workload 2>&1 ") - sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + sim_time=$(echo "$output" | grep "Wall-clock time for simulation:" | tail -n 1 | sed -E 's/.*Wall-clock time for simulation: ([0-9]+\.[0-9]+) seconds.*/\1/') if [[ -n "$sim_time" ]]; then echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh b/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh index ca4cfa39..b1a43cb5 100755 --- a/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh +++ b/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh @@ -2,10 +2,7 @@ base_dir=$TORCHSIM_DIR/experiments/artifact/speedup config=( - # "systolic_ws_8x8_c1_simple_noc.json" - "systolic_ws_128x128_c2_simple_noc_tpuv3.json" - #"systolic_ws_128x128_c2_booksim_tpuv3.json" - # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3_ils.yml" ) TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") SIZE_LIST=( @@ -38,15 +35,11 @@ for i in "${config[@]}"; do for iter in {1..5}; do echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" output=$(bash -c " - export TORCHSIM_TLS_MODE=0; - export TORCHSIM_VALIDATION_MODE=0; - export TORCHSIM_CONFIG=$config_path; - export AUTOTUNE=0; - printenv; - python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + export TOGSIM_CONFIG=$config_path; + cd $TORCHSIM_DIR && python3 $workload 2>&1 ") - sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + sim_time=$(echo "$output" | grep "Wall-clock time for simulation:" | tail -n 1 | sed -E 's/.*Wall-clock time for simulation: ([0-9]+\.[0-9]+) seconds.*/\1/') if [[ -n "$sim_time" ]]; then echo "Iteration $iter: Simulation time = $sim_time" diff --git a/experiments/attention.py b/experiments/attention.py index bbd2734e..db0f45bb 100644 --- a/experiments/attention.py +++ b/experiments/attention.py @@ -1,56 +1,36 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys +import math import argparse -import datetime +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) -def run_attention(size, config): - def attention(query, key, value): - import math - d_k = query.size(-1) - scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-2) - return torch.matmul(value.transpose(-1, -2), p_attn) - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - query = torch.randn(size).to(device=device) - key = torch.randn(size).to(device=device) - value = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(attention) - - SchedulerDNNModel.register_model("attention", opt_fn) - request = Request("attention", [query, key, value], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config - print(f"Attention {str(size)} Simulation Done") +def attention(query, key, value): + d_k = query.size(-1) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[12, 512, 64], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"attention_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] + size = tuple(args.size) + + device = torch.device("npu:0") + query = torch.randn(*size).to(device=device) + key = torch.randn(*size).to(device=device) + value = torch.randn(*size).to(device=device) + opt_fn = torch.compile(dynamic=False)(attention) - run_attention(size, config) + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, query, key, value, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"Attention {size} Simulation Done") diff --git a/experiments/conv.py b/experiments/conv.py index f439c5e3..65e52635 100644 --- a/experiments/conv.py +++ b/experiments/conv.py @@ -1,57 +1,39 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) + +import torch +from Simulator.simulator import TOGSimulator + +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config -def run_conv2d(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - def custom_conv2d(a, b, bias): - i_c = a.shape[1] - o_c = b.shape[0] - conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=False) +def conv2d_fn(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding): + def _conv(a, b, bias): + conv2d = torch.nn.Conv2d(i_c, o_c, kernel_size, stride=stride, padding=padding, dilation=1, bias=False) conv2d.weight = torch.nn.Parameter(b) - # conv2d.bias = torch.nn.Parameter(bias) return conv2d(a) - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() + return _conv + +if __name__ == "__main__": + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[8, 28, 28, 128, 128, 3, 1, 1], + help='B H W I_C O_C K S P') + args = args.parse_args() + batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding = args.size + + device = torch.device("npu:0") conv_input = torch.randn(batch_size, i_c, i_h, i_w).to(memory_format=torch.channels_last, device=device) conv_kernel = torch.randn(o_c, i_c, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) conv_bias = torch.randn(o_c).to(device=device) - opt_fn = torch.compile(dynamic=False)(custom_conv2d) - - SchedulerDNNModel.register_model("CONV", opt_fn) - request = Request("CONV", [conv_input, conv_kernel, conv_bias], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - print(f"CONV {batch_size}_{i_h}_{i_w}_{i_c}_{o_c}_{kernel_size}_{stride}_{padding} (B_H_W_I_C_O_C_K_S_P) Simulation Done") + custom_conv = conv2d_fn(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding) + opt_fn = torch.compile(dynamic=False)(custom_conv) -if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) - args = argparse.ArgumentParser() - args.add_argument('--size', nargs='+', type=int, default=[8, 28, 28, 128, 128, 3, 1, 1], help='B H W I_C O_C K S P') - args.add_argument('--dump_path', type=str, default='results') - args = args.parse_args() - size = args.size - size_str = "_".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"CONV_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_conv2d(size[0], size[1], size[2], size[3], size[4], size[5], size[6], size[7], config) \ No newline at end of file + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, conv_input, conv_kernel, conv_bias, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"CONV {batch_size}_{i_h}_{i_w}_{i_c}_{o_c}_{kernel_size}_{stride}_{padding} Simulation Done") diff --git a/experiments/gemm.py b/experiments/gemm.py index e92200d1..dbbba3ea 100644 --- a/experiments/gemm.py +++ b/experiments/gemm.py @@ -1,54 +1,32 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_matmul(input_size, hidden_size, output_size, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - def custom_matmul(a, b): - return torch.matmul(a, b) - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - torch.manual_seed(0) - input = torch.randn(input_size, hidden_size).to(device=device) - weight = torch.randn(hidden_size, output_size).to(device=device) - opt_fn = torch.compile(dynamic=False)(custom_matmul) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("GEMM", opt_fn) - request = Request("GEMM", [input, weight], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config - print(f"GEMM {input_size}x{hidden_size}x{output_size} (MxKxN) Simulation Done") +def matmul_fn(a, b): + return torch.matmul(a, b) if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[128, 128, 128], help='M K N') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"GEMM_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] + M, K, N = args.size[0], args.size[1], args.size[2] - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() - run_matmul(size[0], size[1], size[2], config) + device = torch.device("npu:0") + torch.manual_seed(0) + input_a = torch.randn(M, K).to(device=device) + input_b = torch.randn(K, N).to(device=device) + opt_fn = torch.compile(dynamic=False)(matmul_fn) + + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, input_a, input_b, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"GEMM {M}x{K}x{N} (MxKxN) Simulation Done") diff --git a/experiments/layernorm.py b/experiments/layernorm.py index 74b6d286..375f98e9 100644 --- a/experiments/layernorm.py +++ b/experiments/layernorm.py @@ -1,48 +1,29 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_layernorm(size, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - input = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(torch.nn.LayerNorm(size[-1]).to(device=device)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("LayerNorm", opt_fn) - request = Request("LayerNorm", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +import torch +from Simulator.simulator import TOGSimulator - print(f"LayerNorm {str(size)} Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[512, 768], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"LayerNorm_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - os.environ['TORCHSIM_FUSION_REDUCTION_REDUCTION'] = "0" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_layernorm(size, config) + size = tuple(args.size) + normalized_shape = size[-1] + + device = torch.device("npu:0") + model = torch.nn.LayerNorm(normalized_shape).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(*size).to(device=device) + + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"LayerNorm {size} Simulation Done") diff --git a/experiments/resnet18.py b/experiments/resnet18.py index 45311d59..ffec9a50 100644 --- a/experiments/resnet18.py +++ b/experiments/resnet18.py @@ -1,49 +1,28 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_resnet(batch, config): - from torchvision.models import resnet18 - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - model = resnet18().eval() - input = torch.randn(batch, 3, 224, 224).to(device=device) - opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("resnet18", opt_fn) - request = Request("resnet18", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from torchvision.models import resnet18 +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - - print("ResNet18 Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--batch', type=int, default=1) - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - batch = args.batch - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet18_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - run_resnet(batch, config) + device = torch.device("npu:0") + model = resnet18().eval().to(device=device, memory_format=torch.channels_last) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(args.batch, 3, 224, 224).to(device=device) + + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print("ResNet18 Simulation Done") diff --git a/experiments/resnet50.py b/experiments/resnet50.py index 4f03ea15..d886c159 100644 --- a/experiments/resnet50.py +++ b/experiments/resnet50.py @@ -1,49 +1,28 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_resnet(batch, config): - from torchvision.models import resnet50 - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - model = resnet50().eval() - input = torch.randn(batch, 3, 224, 224).to(device=device) - opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("resnet50", opt_fn) - request = Request("resnet50", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from torchvision.models import resnet50 +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - - print("ResNet50 Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--batch', type=int, default=1) - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - batch = args.batch - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet50_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - run_resnet(batch, config) + device = torch.device("npu:0") + model = resnet50().eval().to(device=device, memory_format=torch.channels_last) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(args.batch, 3, 224, 224).to(device=device) + + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print("ResNet50 Simulation Done") diff --git a/experiments/softmax.py b/experiments/softmax.py index b47bd685..05024121 100644 --- a/experiments/softmax.py +++ b/experiments/softmax.py @@ -1,47 +1,29 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_softmax(size, config, dim=1): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - input = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(torch.nn.Softmax(dim=dim).to(device=device)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("Softmax", opt_fn) - request = Request("Softmax", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +import torch +from Simulator.simulator import TOGSimulator - print(f"Softmax {str(size)} Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[512, 512], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"Softmax_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_softmax(size, config) + size = tuple(args.size) + dim = 1 + + device = torch.device("npu:0") + model = torch.nn.Softmax(dim=dim).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(*size).to(device=device) + + with TOGSimulator(config_path=config), torch.no_grad(): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"Softmax {size} Simulation Done") diff --git a/scripts/CompilerOpt_experiment/DMAopt.sh b/scripts/CompilerOpt_experiment/DMAopt.sh index 5c2dc65c..9f3a9df2 100644 --- a/scripts/CompilerOpt_experiment/DMAopt.sh +++ b/scripts/CompilerOpt_experiment/DMAopt.sh @@ -1,5 +1,5 @@ #!/bin/bash -export TORCHSIM_CONFIG="/root/workspace/PyTorchSim/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json" +export TOGSIM_CONFIG="/root/workspace/PyTorchSim/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.yml" # None FG DMA export TORCHSIM_SUBTILE=0 diff --git a/scripts/ILS_experiment/test_matmul.py b/scripts/ILS_experiment/test_matmul.py index 667dfc66..b0bc474c 100644 --- a/scripts/ILS_experiment/test_matmul.py +++ b/scripts/ILS_experiment/test_matmul.py @@ -52,15 +52,9 @@ def custom_matmul(bias, a, b): test_result("Addmm Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run matmul with given shape") parser.add_argument('--shape', type=str, default="(512,512,512)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() - test_matmul(device, *shape) + device = torch.device("npu:0") + test_matmul(device, *shape) \ No newline at end of file diff --git a/scripts/chiplet.sh b/scripts/chiplet.sh index 0d56ecae..e622874b 100755 --- a/scripts/chiplet.sh +++ b/scripts/chiplet.sh @@ -19,11 +19,11 @@ GEMM_DIR_NAME=$(basename "$GEMM_PATH") echo "GEMM Directory Name: $GEMM_DIR_NAME" CONFIG_LIST=( - "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json" + "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3.yml" ) CONFIG_LIST2=( - "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_booksim_tpuv3.json" - "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json" + "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_booksim_tpuv3.yml" + "$TORCHSIM_DIR/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.yml" ) shift shift @@ -39,7 +39,7 @@ MODELS_LIST="$GEMM_PATH/tile_graph.onnx" ATTRIBUTE_PATH="$GEMM_PATH/runtime_0000/attribute" for CONFIG in "${CONFIG_LIST[@]}"; do - CONFIG_NAME=$(basename "$CONFIG" .json) + CONFIG_NAME=$(basename "$CONFIG" .yml) for ATTRIBUTE_FILE in "${ATTRIBUTE_FILES[@]}"; do ATTRIBUTE_NAME=$(basename "$ATTRIBUTE_FILE") @@ -56,7 +56,7 @@ for CONFIG in "${CONFIG_LIST[@]}"; do done for CONFIG in "${CONFIG_LIST2[@]}"; do - CONFIG_NAME=$(basename "$CONFIG" .json) + CONFIG_NAME=$(basename "$CONFIG" .yml) ATTRIBUTE_NAME=0 RESULTS_DIR="./chiplet_results$INDEX_NAME/$GEMM_DIR_NAME/$ATTRIBUTE_NAME" mkdir -p "$RESULTS_DIR" diff --git a/scripts/chiplet_prep.py b/scripts/chiplet_prep.py index 32f7ad50..2266d74c 100644 --- a/scripts/chiplet_prep.py +++ b/scripts/chiplet_prep.py @@ -1,10 +1,7 @@ import os -import json -import shutil +import yaml import argparse import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -41,9 +38,11 @@ def modify_file(dump_path, name, address_numa_stride=None, subgraph_map=None): if not os.path.exists(file_path): print(f"File {file_path} does not exist.") return + with open(file_path, 'r') as f: - data = json.load(f) - # address_numa_stride와 subgraph_map 추가 + data = yaml.safe_load(f) + + # address_numa_stride, subgraph_map if address_numa_stride: data['address_numa_stride'] = address_numa_stride if subgraph_map: @@ -52,25 +51,20 @@ def modify_file(dump_path, name, address_numa_stride=None, subgraph_map=None): output_path = file_path = os.path.join(dump_path, 'runtime_0000', 'attribute') os.makedirs(output_path, exist_ok=True) output_file = os.path.join(output_path, name) + with open(output_file, 'w') as f: - json.dump(data, f, indent=4) + yaml.dump(data, f, default_flow_style=False, sort_keys=False) print(f"Modified file saved to {output_file}") if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") parser = argparse.ArgumentParser(description='Process folder argument.') parser.add_argument('size', type=int, help='Folder value', default=256) args = parser.parse_args() folder = int(args.size) print("Taget size: ", folder) - folder_path = os.environ.get("TORCHSIM_DUMP_PATH") + folder_path = os.environ.get("TORCHSIM_LOG_PATH") print(folder_path) os.makedirs(folder_path, exist_ok=True) test_matmul(device, folder, folder, folder) diff --git a/scripts/chiplet_prep.sh b/scripts/chiplet_prep.sh index cddf1a58..f3bd1a1c 100755 --- a/scripts/chiplet_prep.sh +++ b/scripts/chiplet_prep.sh @@ -8,7 +8,7 @@ for size in "${sizes[@]}"; do export TORCHSIM_TILE_M=$((size / 2)) export TORCHSIM_TILE_K=$((size / 2)) export TORCHSIM_TILE_N=$((size / 2)) - export TORCHSIM_DUMP_PATH=$(pwd)/chiplet_result/$size + export TORCHSIM_LOG_PATH=$(pwd)/chiplet_result/$size python3 chiplet_prep.py $size #python3 chiplet_run.py $(pwd)/chiplet_result done \ No newline at end of file diff --git a/scripts/ci/thirdparty_base_pin.sh b/scripts/ci/thirdparty_base_pin.sh new file mode 100755 index 00000000..6cfc7d9a --- /dev/null +++ b/scripts/ci/thirdparty_base_pin.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +# Deterministic short pin for tagging torchsim_base images (thirdparty + base Dockerfile). +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$ROOT" +{ cat thirdparty/github-releases.json; cat Dockerfile.base; } | sha256sum | awk '{print substr($1,1,12)}' diff --git a/scripts/ci/thirdparty_github_asset_env.sh b/scripts/ci/thirdparty_github_asset_env.sh new file mode 100755 index 00000000..8cbe9e12 --- /dev/null +++ b/scripts/ci/thirdparty_github_asset_env.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# Emit GEM5_ASSET_ID, LLVM_ASSET_ID, SPIKE_ASSET_ID lines for appending to GITHUB_ENV. +# Requires: jq, curl, GITHUB_TOKEN, repo root as cwd or GITHUB_WORKSPACE. +set -euo pipefail +ROOT="${GITHUB_WORKSPACE:-$(cd "$(dirname "$0")/../.." && pwd)}" +MANIFEST="${ROOT}/thirdparty/github-releases.json" +if [ ! -f "$MANIFEST" ]; then + echo "Missing thirdparty manifest: $MANIFEST" >&2 + exit 1 +fi +if [ -z "${GITHUB_TOKEN:-}" ]; then + echo "GITHUB_TOKEN is not set" >&2 + exit 1 +fi + +thirdparty_asset_id() { + local key="$1" + local out_var="$2" + local repo release_tag asset_name owner name api_url tmp id + repo=$(jq -r --arg k "$key" '.[$k].repository' "$MANIFEST") + release_tag=$(jq -r --arg k "$key" '.[$k].release_tag' "$MANIFEST") + asset_name=$(jq -r --arg k "$key" '.[$k].asset_name // ""' "$MANIFEST") + owner="${repo%%/*}" + name="${repo##*/}" + if [ "$release_tag" = "latest" ]; then + api_url="https://api.github.com/repos/${owner}/${name}/releases/latest" + else + api_url="https://api.github.com/repos/${owner}/${name}/releases/tags/${release_tag}" + fi + tmp=$(mktemp) + if ! curl -fsS -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github+json" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + "$api_url" -o "$tmp"; then + echo "Failed to fetch release metadata for ${key} (${owner}/${name}, ${release_tag})" >&2 + rm -f "$tmp" + exit 1 + fi + if [ -n "$asset_name" ]; then + id=$(jq -r --arg n "$asset_name" '.assets[] | select(.name == $n) | .id' "$tmp" | head -n1) + else + id=$(jq -r '.assets[0].id' "$tmp") + fi + rm -f "$tmp" + if [ -z "$id" ] || [ "$id" = "null" ]; then + echo "Could not resolve asset id for ${key} (${owner}/${name}, tag=${release_tag}, asset_name=${asset_name:-})" >&2 + exit 1 + fi + echo "${out_var}=${id}" +} + +thirdparty_asset_id gem5 GEM5_ASSET_ID +thirdparty_asset_id llvm_project LLVM_ASSET_ID +thirdparty_asset_id spike SPIKE_ASSET_ID diff --git a/scripts/sparsity_experiment/run.sh b/scripts/sparsity_experiment/run.sh index 4f5dd3a6..7996b5ab 100755 --- a/scripts/sparsity_experiment/run.sh +++ b/scripts/sparsity_experiment/run.sh @@ -1,11 +1,11 @@ -export TORCHSIM_DUMP_PATH=$(pwd)/result +export TORCHSIM_LOG_PATH=$(pwd)/result export SPIKE_DUMP_SPARSE_TILE=1 export TORCHSIM_FORCE_TIME_K=8 export TORCHSIM_FORCE_TIME_M=8 export TORCHSIM_FORCE_TIME_N=8 OUTPUT_DIR="12GB" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_12G_simple_noc.json" +export TOGSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_12G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -13,7 +13,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="24GB" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_24G_simple_noc.json" +export TOGSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_24G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -21,7 +21,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="48GB" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_48G_simple_noc.json" +export TOGSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c1_48G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -29,7 +29,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="12GB_2core" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_12G_simple_noc.json" +export TOGSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_12G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -37,7 +37,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="24GB_2core" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_24G_simple_noc.json" +export TOGSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_24G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 @@ -45,7 +45,7 @@ python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 OUTPUT_DIR="48GB_2core" -export TORCHSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_48G_simple_noc.json" +export TOGSIM_CONFIG="/workspace/PyTorchSim/configs/systolic_ws_8x8_c2_48G_simple_noc.yml" python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 diff --git a/scripts/stonne_experiment/run.sh b/scripts/stonne_experiment/run.sh index 1825817f..2e386d9c 100755 --- a/scripts/stonne_experiment/run.sh +++ b/scripts/stonne_experiment/run.sh @@ -2,8 +2,8 @@ export TORCHSIM_FORCE_TIME_M=1024 export TORCHSIM_FORCE_TIME_K=1024 export TORCHSIM_FORCE_TIME_N=1024 -python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config stonne_big_c1_simple_noc.json --mode 0 > hetero/big_sparse.log -python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config systolic_ws_128x128_c1_simple_noc_tpuv3_half.json --mode 1 > hetero/big.log -python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config heterogeneous_c2_simple_noc.json --mode 2 > hetero/hetero.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config stonne_big_c1_simple_noc.yml --mode 0 > hetero/big_sparse.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config systolic_ws_128x128_c1_simple_noc_tpuv3_half.yml --mode 1 > hetero/big.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config heterogeneous_c2_simple_noc.yml --mode 2 > hetero/hetero.log echo "All processes completed!" diff --git a/scripts/stonne_experiment2/tog_gen.py b/scripts/stonne_experiment2/tog_gen.py index d4f93d4d..0e4b5812 100644 --- a/scripts/stonne_experiment2/tog_gen.py +++ b/scripts/stonne_experiment2/tog_gen.py @@ -71,10 +71,8 @@ def extract_simulation_stats(result_path): if "outerPro" in path: continue tog_path = os.path.join(path, "tile_graph.onnx") - togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") - stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_validation_c1_simple_noc.json' - backsim = TOGSimulator(togsim_path, stonne_config_path) - result_path = backsim.simulation(tog_path) + stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/configs/stonne_validation_c1_simple_noc.yml' + result_path = TOGSimulator.run_standalone(tog_path, config_path=stonne_config_path) nr_multiplications, total_cycle, sim_time = extract_simulation_stats(result_path) sim_time, total_cycle = float(sim_time), int(total_cycle) print(f"[TLS] Cycle={total_cycle} Sim time={sim_time} nr_multiplications={nr_multiplications}") diff --git a/tests/DeepSeek/test_deepseek_v3_base.py b/tests/DeepSeek/test_deepseek_v3_base.py new file mode 100644 index 00000000..ade787c5 --- /dev/null +++ b/tests/DeepSeek/test_deepseek_v3_base.py @@ -0,0 +1,330 @@ +import os +import sys +import argparse +import copy +from pathlib import Path +import torch + +# recursive compile for some ops that are caused by graph break +torch.npu.register_eager_to_compile([ + "aten::zero_", + "aten::sum.IntList_out", + "aten::mul.out", + "aten::floor_divide", + "aten::floor_divide.Tensor", + "aten::floor_divide.Scalar", + "aten::cat.out", + "aten::sort.values_stable", +]) + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + out_cpu = out.cpu() + max_diff = (out_cpu - cpu_out).abs().max().item() + mean_diff = (out_cpu - cpu_out).abs().mean().item() + if torch.allclose(out_cpu, cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("NPU out: ", out_cpu) + print("CPU out: ", cpu_out) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + exit(1) + + +def _extract_logits(output): + if isinstance(output, torch.Tensor): + return output + if hasattr(output, "logits"): + return output.logits + if isinstance(output, (list, tuple)) and len(output) > 0 and isinstance(output[0], torch.Tensor): + return output[0] + raise TypeError(f"Unsupported output type for comparison: {type(output)}") + + +def _dtype_from_str(name: str) -> torch.dtype: + return { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }.get(name, torch.float32) + + +def _build_random_inputs(batch, seq_len, vocab_size, device): + g = torch.Generator().manual_seed(0) + input_ids = torch.randint(0, vocab_size, (batch, seq_len), generator=g, dtype=torch.int64) + return input_ids.to(device) + + +def _safe_scaled_int(value, scale, min_value=1): + return max(min_value, int(round(float(value) * float(scale)))) + + +def _round_to_multiple(value, multiple, min_value=1): + if multiple is None or multiple <= 0: + return max(min_value, int(value)) + v = max(min_value, int(value)) + return max(min_value, ((v + multiple - 1) // multiple) * multiple) + + +def _maybe_scale_config(config, scale=1.0, max_layers=None): + if scale == 1.0 and max_layers is None: + return config + + if hasattr(config, "hidden_size"): + config.hidden_size = _safe_scaled_int(config.hidden_size, scale) + if hasattr(config, "intermediate_size"): + config.intermediate_size = _safe_scaled_int(config.intermediate_size, scale) + if hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = _safe_scaled_int(config.num_hidden_layers, scale) + if hasattr(config, "num_attention_heads"): + config.num_attention_heads = _safe_scaled_int(config.num_attention_heads, scale) + if hasattr(config, "num_key_value_heads"): + config.num_key_value_heads = min( + _safe_scaled_int(config.num_key_value_heads, scale), + config.num_attention_heads, + ) + + for name in [ + "n_routed_experts", + "n_shared_experts", + "num_local_experts", + "num_experts", + "num_experts_per_tok", + "moe_intermediate_size", + "shared_expert_intermediate_size", + ]: + if hasattr(config, name): + setattr(config, name, _safe_scaled_int(getattr(config, name), scale)) + + # DeepSeek MoE gate expects n_routed_experts to be divisible by n_group. + if hasattr(config, "n_routed_experts") and hasattr(config, "n_group"): + config.n_routed_experts = _round_to_multiple( + config.n_routed_experts, + config.n_group, + min_value=max(1, int(config.n_group)), + ) + + if max_layers is not None and hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = max(1, min(int(max_layers), int(config.num_hidden_layers))) + + if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + config.hidden_size = max( + config.num_attention_heads, + (config.hidden_size // config.num_attention_heads) * config.num_attention_heads, + ) + + return config + + +def _apply_preset(scale, max_layers, batch, seq_len, preset): + if preset == "tiny": + return 0.03, 1, 1, min(seq_len, 16) + if preset == "small": + return 0.07, 8, 1, min(seq_len, 32) + if preset == "medium": + return 0.10, 12, 1, min(seq_len, 48) + return scale, max_layers, batch, seq_len + + +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + +@torch.no_grad() +def run_deepseek_v3_base( + model_id, + device, + init_mode="config-random", + scale=1.0, + max_layers=None, + dtype="float16", + batch=1, + seq_len=32, + use_tokenizer=False, + prompt="Hello, DeepSeek V3", + trust_remote_code=False, + revision=None, + compile_model=False, +): + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + + torch_dtype = _dtype_from_str(dtype) + + # Load model config + config = AutoConfig.from_pretrained( + model_id, + trust_remote_code=trust_remote_code, + revision=revision, + ) + + # Some remote model codes expect quantization_config to stay object-like + # (call .to_dict()), so only disable it for pretrained loading path. + if init_mode == "pretrained" and getattr(config, "quantization_config", None) is not None: + config.quantization_config = None + config = _maybe_scale_config(config, scale=scale, max_layers=max_layers) + + if init_mode == "config-random": + model = AutoModelForCausalLM.from_config( + config=config, + trust_remote_code=trust_remote_code, + ).eval() + model = model.to(dtype=torch_dtype) + elif init_mode == "pretrained": + # Load model(weights) + model = AutoModelForCausalLM.from_pretrained( + model_id, + config=config, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + revision=revision, + ).eval() + else: + raise ValueError(f"Unsupported init mode: {init_mode}") + + model_params = sum(p.numel() for p in model.parameters()) + print("init mode:", init_mode) + print("scaled hidden_size:", getattr(config, "hidden_size", "n/a")) + print("scaled num_hidden_layers:", getattr(config, "num_hidden_layers", "n/a")) + print("scaled num_attention_heads:", getattr(config, "num_attention_heads", "n/a")) + print("model params:", model_params) + + # Load tokenizer + if use_tokenizer: + tokenizer = AutoTokenizer.from_pretrained( + model_id, + trust_remote_code=trust_remote_code, + revision=revision, + ) + encoded = tokenizer(prompt, return_tensors="pt") + cpu_input_ids = encoded["input_ids"].cpu() + else: + vocab_size = getattr(config, "vocab_size", None) + if vocab_size is None: + raise ValueError("Config has no vocab_size; use --use-tokenizer or pass a model with vocab_size.") + cpu_input_ids = _build_random_inputs(batch, seq_len, vocab_size, torch.device("cpu")) + input_ids = cpu_input_ids.to(device) + + # CPU version + model_cpu = copy.deepcopy(model).cpu().eval() + cpu_out = _extract_logits(model_cpu(cpu_input_ids)) + + # NPU version + model_npu = copy.deepcopy(model_cpu).to(device).eval() + if compile_model: + model_npu = torch.compile(model_npu, dynamic=False) + npu_out = _extract_logits(model_npu(input_ids)) + + # Campare results + test_result( + "DeepSeek V3 Base", + npu_out, + cpu_out, + rtol=3e-1, + atol=2e-1, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="DeepSeek V3 download-based test") + parser.add_argument("--model-id", type=str, default=os.environ.get("DEEPSEEK_V3_MODEL_ID", "deepseek-ai/DeepSeek-V3-Base")) + parser.add_argument("--revision", type=str, default=None) + parser.add_argument("--trust-remote-code", action="store_true", default=True) + parser.add_argument("--init-mode", type=str, default="config-random", choices=["config-random", "pretrained"]) + parser.add_argument("--preset", type=str, default="small", choices=["none", "tiny", "small", "medium"]) + parser.add_argument("--scale", type=float, default=1.0) + parser.add_argument("--max-layers", type=int, default=None) + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=32) + parser.add_argument("--use-tokenizer", action="store_true") + parser.add_argument("--prompt", type=str, default="Hello, DeepSeek V3") + parser.add_argument("--compile", action="store_true", default=True) + parser.add_argument("--test", type=str, default="e2e", choices=["all", "e2e", "cat"]) + + args = parser.parse_args() + + if not args.model_id: + print("Error: --model-id is required (or set DEEPSEEK_V3_MODEL_ID).", file=sys.stderr) + sys.exit(2) + + args.scale, args.max_layers, args.batch, args.seq_len = _apply_preset( + args.scale, args.max_layers, args.batch, args.seq_len, args.preset + ) + + device = torch.device("npu:0") + + if args.test in ("all", "cat"): + test_cat_default(device) + test_cat_out(device) + if args.test in ("all", "e2e"): + run_deepseek_v3_base( + model_id=args.model_id, + device=device, + init_mode=args.init_mode, + scale=args.scale, + max_layers=args.max_layers, + dtype=args.dtype, + batch=args.batch, + seq_len=args.seq_len, + use_tokenizer=args.use_tokenizer, + prompt=args.prompt, + trust_remote_code=args.trust_remote_code, + revision=args.revision, + compile_model=args.compile, + ) diff --git a/tests/Diffusion/test_diffusion.py b/tests/Diffusion/test_diffusion.py index c5170209..85eaba9f 100644 --- a/tests/Diffusion/test_diffusion.py +++ b/tests/Diffusion/test_diffusion.py @@ -8,6 +8,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.upsampling import Upsample2D from diffusers.models.resnet import ResnetBlock2D +from diffusers.models.embeddings import Timesteps def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -313,7 +314,7 @@ def test_cross_attn_down_block2d( dual_cross_attention=False ): print(f"Testing CrossAttnDownBlock2D on device: {device}") - + # 1. Initialize the module on CPU cpu_block = CrossAttnDownBlock2D( in_channels=in_channels, @@ -338,7 +339,7 @@ def test_cross_attn_down_block2d( temb=temb_cpu, encoder_hidden_states=encoder_hidden_states_cpu, ) - + # 4. Initialize the module on the custom device device_block = cpu_block.to(device).eval() device_block = torch.compile(device_block, dynamic=False) @@ -347,7 +348,7 @@ def test_cross_attn_down_block2d( hidden_states_dev = hidden_states_cpu.to(device) temb_dev = temb_cpu.to(device) encoder_hidden_states_dev = encoder_hidden_states_cpu.to(device) - + # 6. Get the output from the custom device module with torch.no_grad(): dev_out, _ = device_block( @@ -442,9 +443,9 @@ def test_groupnorm( # 1. Initialize the module on CPU cpu_norm = torch.nn.GroupNorm( - num_groups=num_groups, - num_channels=channels, - eps=eps, + num_groups=num_groups, + num_channels=channels, + eps=eps, affine=True ).to("cpu").eval() @@ -462,13 +463,13 @@ def test_groupnorm( # 4. Initialize the module on the custom device device_norm = torch.nn.GroupNorm( - num_groups=num_groups, - num_channels=channels, - eps=eps, + num_groups=num_groups, + num_channels=channels, + eps=eps, affine=True ).to(device).eval() device_norm = torch.compile(device_norm, dynamic=False) - + # Copy the weights from the CPU module to ensure they are identical device_norm.weight.data.copy_(cpu_norm.weight.data) device_norm.bias.data.copy_(cpu_norm.bias.data) @@ -541,6 +542,89 @@ def test_upsample2d( print("Max diff >", torch.max(torch.abs(y_dev.cpu() - y_cpu)).item()) print("Upsample2D simulation done.") + +def test_flip_sin_to_cos_embedding( + device, + batch=1, + embedding_dim=256, + rtol=1e-4, + atol=1e-4, +): + def create_embeddings(timesteps, embedding_dim, scale=1.0, flip_sin_to_cos=False): + """ + Replicate the embedding creation logic from Timesteps class. + """ + half_dim = embedding_dim // 2 + exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / half_dim + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + emb = scale * emb + + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + # flip sine and cosine embeddings + if flip_sin_to_cos: + new_emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + return emb, new_emb + return emb, emb + + g = torch.Generator().manual_seed(0) + timesteps_cpu = torch.randint(low=0, high=1000, size=(batch,), generator=g, dtype=torch.long) + + # Test with flip_sin_to_cos=True + with torch.no_grad(): + emb_flip_cpu = create_embeddings(timesteps_cpu, embedding_dim, flip_sin_to_cos=True) + + # Move to device and test + timesteps_dev = timesteps_cpu.to(device) + @torch.compile(dynamic=False) + def create_embeddings_compiled(timesteps, embedding_dim, scale=1.0, flip_sin_to_cos=False): + return create_embeddings(timesteps, embedding_dim, scale, flip_sin_to_cos) + + with torch.no_grad(): + emb_flip_dev = create_embeddings_compiled(timesteps_dev, embedding_dim, flip_sin_to_cos=True) + + # Verify flip case + test_result("Embedding (flip_sin_to_cos=True)", emb_flip_dev[0], emb_flip_cpu[0], rtol=rtol, atol=atol) + print("Max diff (flip) >", torch.max(torch.abs(emb_flip_dev[0].cpu() - emb_flip_cpu[0])).item()) + test_result("Embedding (flip_sin_to_cos=True)", emb_flip_dev[1], emb_flip_cpu[1], rtol=rtol, atol=atol) + print("Max diff (flip) >", torch.max(torch.abs(emb_flip_dev[1].cpu() - emb_flip_cpu[1])).item()) + + +def test_timesteps( + device, + batch=1, + num_channels=64, + flip_sin_to_cos=True, + downscale_freq_shift=1.0, + rtol=1e-4, + atol=1e-4, +): + print(f"Testing Timesteps on device: {device}") + + cpu_timesteps = Timesteps( + num_channels=num_channels, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + ).to("cpu").eval() + + g = torch.Generator().manual_seed(0) + timesteps_cpu = torch.randint(low=0, high=1000, size=(batch,), generator=g, dtype=torch.long) + + with torch.no_grad(): + cpu_out = cpu_timesteps(timesteps_cpu) + + dev_timesteps = cpu_timesteps.to(device).eval() + dev_timesteps = torch.compile(dev_timesteps, dynamic=False) + + timesteps_dev = timesteps_cpu.to(device) + with torch.no_grad(): + dev_out = dev_timesteps(timesteps_dev) + + test_result("Timesteps", dev_out, cpu_out, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(dev_out.cpu() - cpu_out)).item()) + print("Timesteps simulation done.") + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run UNet (diffusers) test with comparison") parser.add_argument("--model", type=str, default="runwayml/stable-diffusion-v1-5", @@ -553,18 +637,18 @@ def test_upsample2d( args = parser.parse_args() sys.path.append(os.environ.get("TORCHSIM_DIR", "/workspace/PyTorchSim")) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_upsample2d(device) #test_groupnorm(device) #test_groupnorm(device, stride=[1, 1, 320*32, 320]) - #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=320) + #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=256, resnet_act_fn='silu') #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=1280) #test_cross_attn_down_block2d(device) #test_unet_mid_block2d_cross_attn(device) #test_cross_attn_up_block2d(device) + #test_flip_sin_to_cos_embedding(device) + #test_timesteps(device) test_unet2d_condition_model(device) #test_unet_conditional( # device=device, diff --git a/tests/Fusion/__init__.py b/tests/Fusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/Fusion/test_addmm_residual.py b/tests/Fusion/test_addmm_residual.py index ef753a67..a2c17207 100644 --- a/tests/Fusion/test_addmm_residual.py +++ b/tests/Fusion/test_addmm_residual.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -39,13 +37,7 @@ def addmm_residual(a, b, c, d): test_result("Addmm + Residual Fusion Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_addmm_residual(device, 32, 32, 32) test_addmm_residual(device, 128, 128, 128) test_addmm_residual(device, 512, 512, 512) diff --git a/tests/Fusion/test_attention_fusion.py b/tests/Fusion/test_attention_fusion.py index 123376d1..93a17347 100644 --- a/tests/Fusion/test_attention_fusion.py +++ b/tests/Fusion/test_attention_fusion.py @@ -1,8 +1,5 @@ -import math import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -71,13 +68,7 @@ def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): test_result("MHA Forward", res, cpu_res) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_MHA(device) # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/Fusion/test_bmm_reduction.py b/tests/Fusion/test_bmm_reduction.py index 4f4d3ad6..45e31dab 100644 --- a/tests/Fusion/test_bmm_reduction.py +++ b/tests/Fusion/test_bmm_reduction.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -38,13 +36,7 @@ def bmm(a, b): test_result("BMM Reduction Fusion reduction", res[1], y[1]) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_bmm_reduce(device) test_bmm_reduce(device, 12, 512) test_bmm_reduce(device, 4, 256) diff --git a/tests/Fusion/test_conv_fusion.py b/tests/Fusion/test_conv_fusion.py index 694f3bb9..bc200ff2 100644 --- a/tests/Fusion/test_conv_fusion.py +++ b/tests/Fusion/test_conv_fusion.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): message = f"|{name} Test Passed|" @@ -97,13 +95,7 @@ def custom_conv_bn_relu(a, b, bias, c, d, e, f): print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") # Vanila test test_conv_residual(device, batch_size=3, in_channels=64, out_channels=64, input_size=28, kernel_size=3, stride=1, padding=1) diff --git a/tests/Fusion/test_matmul_activation.py b/tests/Fusion/test_matmul_activation.py index 2f1d014f..232ec98d 100644 --- a/tests/Fusion/test_matmul_activation.py +++ b/tests/Fusion/test_matmul_activation.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -69,13 +67,7 @@ def test_matmul_activation(device, batch_size=16, input_size=32, output_size=8, print("CPU output > ", cpu_y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul_activation(device) test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="sigmoid") test_matmul_activation(device, batch_size=42, input_size=42, output_size=42, activation_fn="sigmoid") diff --git a/tests/Fusion/test_matmul_reduction.py b/tests/Fusion/test_matmul_reduction.py index df8cf969..9b09214a 100644 --- a/tests/Fusion/test_matmul_reduction.py +++ b/tests/Fusion/test_matmul_reduction.py @@ -85,13 +85,7 @@ def matmul_fused(a, b, c, d): test_result("Matmul+residual+var_mean Fusion reduction", res[2], y[2]) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul_reduce(device, 3072, 512, 768) test_matmul_var_mean(device) test_matmul_add_var_mean(device) diff --git a/tests/Fusion/test_matmul_scalar.py b/tests/Fusion/test_matmul_scalar.py index 0815bb90..d5a159ed 100644 --- a/tests/Fusion/test_matmul_scalar.py +++ b/tests/Fusion/test_matmul_scalar.py @@ -35,11 +35,5 @@ def matmul_fused(a, b, c): test_result("Matmul Scalar Fusion Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul_scalar(device) diff --git a/tests/Fusion/test_matmul_vector.py b/tests/Fusion/test_matmul_vector.py new file mode 100644 index 00000000..f87f9432 --- /dev/null +++ b/tests/Fusion/test_matmul_vector.py @@ -0,0 +1,44 @@ +import torch + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_matmul_vector(device, size=[56, 78, 239], dim=0): + def matmul_fused(a, b, c, d): + return torch.matmul(a, b) + c + d + torch.manual_seed(0) + input = torch.randn(size[:2]) + weight = torch.randn(size[1:]) + output_sz = [size[0], size[2]] + output_sz[dim]=1 + bias = torch.zeros(output_sz) + add = torch.zeros(output_sz) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + a1 = add.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + a2 = add.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, a1, b1) + y = matmul_fused(x2, w2, a2, b2) + test_result("Matmul Vector Fusion Forward", res, y) + +if __name__ == "__main__": + device = torch.device("npu:0") + test_matmul_vector(device, size=[253, 123, 47], dim=0) + test_matmul_vector(device, size=[253, 123, 47], dim=1) \ No newline at end of file diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py index b27312a9..ecfd5fbf 100644 --- a/tests/Fusion/test_prologue_fusion.py +++ b/tests/Fusion/test_prologue_fusion.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -84,13 +82,7 @@ def bmm(a, b, c, d): test_result("BMM Element-wise Fusion Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_elem_broadcast_fusion(device) test_elem_fusion(device) test_elem_bmm_input_fusion(device, batch_size=4, m=512, n=512, k=64) diff --git a/tests/Fusion/test_transformer_fusion.py b/tests/Fusion/test_transformer_fusion.py index b1cceb2c..1581cd97 100644 --- a/tests/Fusion/test_transformer_fusion.py +++ b/tests/Fusion/test_transformer_fusion.py @@ -1,8 +1,6 @@ import math import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -199,13 +197,7 @@ def test_EncoderBlock_validation(head=12, embed_dim=768, input_seq=512): test_result("Encoder Block Validation", res, origin_res) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_MHA(device) test_EncoderBlock(device) # test_EncoderBlock_validation() diff --git a/tests/Llama/test_llama.py b/tests/Llama/test_llama.py new file mode 100644 index 00000000..5e87b8e7 --- /dev/null +++ b/tests/Llama/test_llama.py @@ -0,0 +1,393 @@ +import os +import sys +import argparse +import copy +import torch +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, LlamaModel + +def test_result(name, out, ref, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), ref.cpu(), rtol=rtol, atol=atol): + msg = f"|{name} Test Passed|" + print("-" * len(msg)); print(msg); print("-" * len(msg)) + else: + msg = f"|{name} Test Failed|" + print("-" * len(msg)); print(msg); print("-" * len(msg)) + diff = (out.cpu().int() - ref.cpu().int()).abs().max().item() + print("device out:", out.detach().cpu()) + print("cpu ref :", ref.detach().cpu()) + print(f"Max abs diff: {diff}") + sys.exit(1) + +@torch.no_grad() +def run_rmsnorm_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaRMSNorm Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + hidden_size = 4096 + eps = 1e-6 + + print(f"Building LlamaRMSNorm (hidden_size={hidden_size}, eps={eps})") + base_norm = LlamaRMSNorm(hidden_size=hidden_size, eps=eps).eval() + cpu_norm = copy.deepcopy(base_norm).eval() + + cpu_norm.to(dtype=torch_dtype, device="cpu") + model = base_norm.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + hidden_states = torch.randn(batch, seq_len, hidden_size, generator=g, dtype=torch_dtype) + hs_dev = hidden_states.to(device) + + print("Compiling LlamaRMSNorm with torch.compile(...)") + compiled_norm = torch.compile(model, dynamic=False) + + out_cpu = cpu_norm(hidden_states) + out_dev = compiled_norm(hs_dev) + + test_result("LlamaRMSNorm forward", out_dev, out_cpu, rtol=rtol, atol=atol) + print("Max diff >", (out_dev.detach().cpu() - out_cpu.detach().cpu()).abs().max().item()) + + +@torch.no_grad() +def run_rotary_embedding_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaRotaryEmbedding Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + hidden_size = 4096 + num_heads = 32 + head_dim = hidden_size // num_heads + + cfg = LlamaConfig( + _name_or_path="custom-llama", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=4096, + mlp_bias=False, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-06, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype=dtype, + transformers_version="4.43.4", + use_cache=True, + vocab_size=8192, + _attn_implementation = "sdpa" + ) + # Pass dim explicitly to avoid config parsing issues + base_rope = LlamaRotaryEmbedding(dim=head_dim, max_position_embeddings=cfg.max_position_embeddings, base=cfg.rope_theta, config=cfg) + + cpu_rope = copy.deepcopy(base_rope) + + cpu_rope.to(device="cpu") + model = base_rope.to(device=device) + + g = torch.Generator().manual_seed(0) + value = torch.randn(batch, num_heads, seq_len, head_dim, generator=g, dtype=torch_dtype) + position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(batch, -1) + + val_dev = value.to(device) + pos_dev = position_ids.to(device) + + print("Compiling LlamaRotaryEmbedding with torch.compile(...)") + compiled_rope = torch.compile(model, dynamic=False) + + cos_cpu, sin_cpu = cpu_rope(value, position_ids) + cos_dev, sin_dev = compiled_rope(val_dev, pos_dev) + + print(f"Output dtype check - CPU: {cos_cpu.dtype}, Device: {cos_dev.dtype}") + + test_result("LlamaRotaryEmbedding (Cos)", cos_dev, cos_cpu, rtol=rtol, atol=atol) + test_result("LlamaRotaryEmbedding (Sin)", sin_dev, sin_cpu, rtol=rtol, atol=atol) + + diff_cos = (cos_dev.detach().cpu() - cos_cpu.detach().cpu()).abs().max().item() + diff_sin = (sin_dev.detach().cpu() - sin_cpu.detach().cpu()).abs().max().item() + print(f"Max diff (Cos) > {diff_cos}") + print(f"Max diff (Sin) > {diff_sin}") + +@torch.no_grad() +def run_decoder_layer_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaDecoderLayer Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + cfg = LlamaConfig( + _name_or_path="custom-llama", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=4096, + mlp_bias=False, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-06, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype=dtype, + transformers_version="4.43.4", + use_cache=True, + vocab_size=8192, + _attn_implementation = "sdpa" + ) + + print("Building LlamaDecoderLayer from custom config.") + base_layer = LlamaDecoderLayer(cfg, layer_idx=0).eval() + cpu_layer = copy.deepcopy(base_layer).eval() + + cpu_layer.to(dtype=torch_dtype, device="cpu") + model = base_layer.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + hidden_states = torch.randn(batch, seq_len, cfg.hidden_size, generator=g, dtype=torch_dtype) + position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(batch, -1) + + attention_mask = torch.zeros(batch, 1, seq_len, seq_len, dtype=torch_dtype) + mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1) + attention_mask.masked_fill_(mask, torch.finfo(torch_dtype).min) + + # Shape: (1, seq_len, head_dim) or (batch, seq_len, head_dim) + head_dim = cfg.hidden_size // cfg.num_attention_heads + cos = torch.randn(1, seq_len, head_dim, generator=g, dtype=torch_dtype) + sin = torch.randn(1, seq_len, head_dim, generator=g, dtype=torch_dtype) + position_embeddings = (cos, sin) + + hs_dev = hidden_states.to(device) + pos_dev = position_ids.to(device) + att_dev = attention_mask.to(device) + pos_emb_dev = (cos.to(device), sin.to(device)) + + print("Compiling LlamaDecoderLayer with torch.compile(...)") + compiled_layer = torch.compile(model, dynamic=False) + + out_cpu = cpu_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings + ) + if isinstance(out_cpu, tuple): + out_cpu = out_cpu[0] + + out_dev = compiled_layer( + hidden_states=hs_dev, + attention_mask=att_dev, + position_ids=pos_dev, + position_embeddings=pos_emb_dev + ) + if isinstance(out_dev, tuple): + out_dev = out_dev[0] + + test_result("LlamaDecoderLayer forward", out_dev, out_cpu, rtol=rtol, atol=atol) + print("Max diff >", (out_dev.detach().cpu() - out_cpu.detach().cpu()).abs().max().item()) + +@torch.no_grad() +def run_custom_llama_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, + max_new_tokens=16, +): + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + cfg = LlamaConfig( + _name_or_path="custom-llama", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=4096, + mlp_bias=False, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-06, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype=dtype, + transformers_version="4.43.4", + use_cache=True, + vocab_size=8192, + ) + + print("Building LlamaForCausalLM from custom config (random init).") + base_model = LlamaForCausalLM(cfg).eval() + cpu_model = copy.deepcopy(base_model).eval() + + cpu_model.to(dtype=torch_dtype, device="cpu") + model = base_model.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + vocab = cfg.vocab_size + input_ids_cpu = torch.randint(low=0, high=vocab, size=(batch, seq_len), generator=g, dtype=torch.long) + + min_dtype = torch.finfo(torch_dtype).min + causal_mask = torch.zeros((seq_len, seq_len), dtype=torch_dtype, device="cpu") + + if seq_len > 1: + causal_mask = torch.triu(torch.full_like(causal_mask, min_dtype), diagonal=1) + + cache_position = torch.arange(seq_len, device="cpu") + mask_condition = torch.arange(seq_len, device="cpu") > cache_position.reshape(-1, 1) + causal_mask.masked_fill_(mask_condition, min_dtype) + attn_mask_cpu = causal_mask[None, None, :, :].expand(batch, 1, -1, -1) + + input_ids_dev = input_ids_cpu.to(device) + attn_mask_dev = attn_mask_cpu.to(device) + + # ---- forward comparison (compile vs CPU baseline) ---- + print("Compiling model with torch.compile(...)") + compiled = torch.compile(model, dynamic=False) + + logits_cpu = cpu_model(input_ids=input_ids_cpu, attention_mask=attn_mask_cpu)#.logits + logits_dev = compiled(input_ids=input_ids_dev, attention_mask=attn_mask_dev)#.logits + + test_result("Custom Llama forward(logits)", logits_dev, logits_cpu, rtol=rtol, atol=atol) + print("Max diff >", (logits_dev.detach().cpu() - logits_cpu.detach().cpu()).abs().max().item()) + +@torch.no_grad() +def run_llama_model_test( + device, + batch=1, + seq_len=32, + dtype="float32", + rtol=1e-3, + atol=1e-3, +): + print("\n[Running LlamaModel Test]") + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + cfg = LlamaConfig( + vocab_size=8192, + hidden_size=1024, + num_attention_heads=32, + num_key_value_heads=32, + intermediate_size=11008 // 4, + num_hidden_layers=1, + max_position_embeddings=4096, + hidden_act="silu", + use_cache=False, + torch_dtype=dtype, + ) + + print("Building LlamaModel from custom config (random init).") + base_model = LlamaModel(cfg).eval() + cpu_model = copy.deepcopy(base_model).eval() + + cpu_model.to(dtype=torch_dtype, device="cpu") + model = base_model.to(dtype=torch_dtype, device=device) + + g = torch.Generator().manual_seed(0) + input_ids_cpu = torch.randint(low=0, high=cfg.vocab_size, size=(batch, seq_len), generator=g, dtype=torch.long) + + min_dtype = torch.finfo(torch_dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=min_dtype, dtype=torch_dtype, device="cpu") + if seq_len > 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + attn_mask_cpu = causal_mask[None, None, :, :].expand(batch, 1, -1, -1) + + input_ids_dev = input_ids_cpu.to(device) + attn_mask_dev = attn_mask_cpu.to(device) + + print("Compiling LlamaModel with torch.compile(...)") + compiled_model = torch.compile(model, dynamic=False) + + out_cpu = cpu_model(input_ids=input_ids_cpu, attention_mask=attn_mask_cpu) + out_dev = compiled_model(input_ids=input_ids_dev, attention_mask=attn_mask_dev) + + last_hidden_state_cpu = out_cpu.last_hidden_state + last_hidden_state_dev = out_dev.last_hidden_state + + test_result("LlamaModel (last_hidden_state)", last_hidden_state_dev, last_hidden_state_cpu, rtol=rtol, atol=atol) + diff = (last_hidden_state_dev.detach().cpu() - last_hidden_state_cpu.detach().cpu()).abs().max().item() + print(f"Max diff > {diff}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test Custom Llama (random weights, no tokenizer)") + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--seq_len", type=int, default=32) + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--rtol", type=float, default=1e-3) + parser.add_argument("--atol", type=float, default=1e-3) + parser.add_argument("--max_new_tokens", type=int, default=16) + args = parser.parse_args() + + sys.path.append(os.environ.get("PYTORCHSIM_ROOT_PATH", "/workspace/PyTorchSim")) + device = torch.device("npu:0") + #test_triu(device, size=(32, 128), diagonal=1) + torch.compiler.is_compiling = lambda: True # FIXME. How to fix this? + #run_rmsnorm_test(device) + #run_rotary_embedding_test(device) + run_decoder_layer_test( + device=device, + batch=args.batch, + seq_len=args.seq_len, + dtype=args.dtype, + rtol=args.rtol, + atol=args.atol, + ) + run_llama_model_test(device) + #run_custom_llama_test( + # device=device, + # batch=args.batch, + # seq_len=args.seq_len, + # dtype=args.dtype, + # rtol=args.rtol, + # atol=args.atol, + #) diff --git a/tests/MLP/test_mlp.py b/tests/MLP/test_mlp.py index 31bcefdf..c910729e 100644 --- a/tests/MLP/test_mlp.py +++ b/tests/MLP/test_mlp.py @@ -281,10 +281,8 @@ def train(model, device): return if __name__ == "__main__": - from Scheduler.scheduler import PyTorchSimRunner torch.set_printoptions(threshold=float('inf'), linewidth=600) - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_mlp(device) # test_train_mlp(device) diff --git a/tests/Mixtral_8x7B/test_attention.py b/tests/Mixtral_8x7B/test_attention.py index 6a7747f7..57760370 100644 --- a/tests/Mixtral_8x7B/test_attention.py +++ b/tests/Mixtral_8x7B/test_attention.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension from model import Transformer, TransformerBlock, ModelArgs, Attention, FeedForward, KVCache, RMSNorm, precompute_freqs_cis, sample def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): @@ -159,15 +157,9 @@ def test_rmsnorm(device, seq=32): test_result("RMSNorm", res, cpu_res) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() - test_rmsnorm(device, seq=1) - test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) + device = torch.device("npu:0") + #test_rmsnorm(device, seq=1) + #test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) test_decode(device, 32, 3) #test_attention(device) #test_ffn(device) diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index ae16f0b0..d4cd98f1 100644 --- a/tests/MoE/test_moe.py +++ b/tests/MoE/test_moe.py @@ -4,7 +4,6 @@ import copy import matplotlib.pyplot as plt - import torch import torch.nn as nn from torch.distributions.normal import Normal @@ -17,6 +16,20 @@ sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +# FIXME. This is a Dynamo bug. Solution to avoid is_forward conflict during backward +def patch_metrics_context_update(): + """Patch MetricsContext.update to set overwrite=True by default.""" + from torch._dynamo.utils import get_metrics_context + ctx = get_metrics_context() + original_update = ctx.update + + def patched_update(values, overwrite=True): + """Patched version that sets overwrite=True by default.""" + return original_update(values, overwrite=True) + + # Patch the method + get_metrics_context().update = patched_update + def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): pass_message = f"|{name} Test Passed|" fail_message = f"|{name} Test Failed|" @@ -64,6 +77,7 @@ class SparseDispatcher(object): `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. """ + @torch.compiler.disable(recursive=True) def __init__(self, num_experts, gates): """Create a SparseDispatcher.""" gates = gates.cpu() @@ -443,6 +457,7 @@ def test_moe(device): total_cpu_loss = cpu_loss + cpu_aux_loss total_loss.to(device) + patch_metrics_context_update() print("Backward Started!") total_loss.backward() total_cpu_loss.backward() @@ -469,6 +484,9 @@ def test_moe(device): print("\n") def train_moe(device): + # Patch CompileEventLogger to avoid metric conflicts + patch_metrics_context_update() + def perceptron(a, b, c): return a * b + c @@ -589,6 +607,9 @@ def weight_update(a, b, lr): plt.savefig('result.png') def train_moe_mnist(device): + # Patch CompileEventLogger to avoid metric conflicts + patch_metrics_context_update() + torch.manual_seed(0) batch_size = 32 input_size = 28*28 @@ -670,6 +691,9 @@ def train(model, device, train_loader, optimizer, epochs): plt.savefig(f'{name}_result.png') def train_moe_single_iteration(device, iter_idx, is_evaluation=0): + # Patch CompileEventLogger to avoid metric conflicts + patch_metrics_context_update() + # Training moe with mnist dataset for sinlge iteration torch.manual_seed(0) batch_size = 128 @@ -783,10 +807,8 @@ def evaluation(model, evaluation_loader): train(opt_model, train_loader) if __name__ == "__main__": - from Scheduler.scheduler import PyTorchSimRunner torch.set_printoptions(threshold=float('inf'), linewidth=600) - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_moe(device) # train_moe(device) diff --git a/tests/MobileNet/test_mobilenet.py b/tests/MobileNet/test_mobilenet.py new file mode 100644 index 00000000..966d479a --- /dev/null +++ b/tests/MobileNet/test_mobilenet.py @@ -0,0 +1,106 @@ +import argparse +import copy +import os + +import torch +import torch._dynamo +import torch.utils.cpp_extension +from torchvision.models import mobilenet_v2 + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + + +def _mobilenet_v2(): + try: + from torchvision.models import MobileNet_V2_Weights + + return mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).cpu().eval() + except Exception: + return mobilenet_v2().cpu().eval() + + +def run_mobilenet(batch, config): + device = torch.device("npu:0") + + torch._dynamo.config.recompile_limit = 64 + torch._dynamo.config.cache_size_limit = 128 + + model = _mobilenet_v2() + imgsz = 224 + x = torch.randn(batch, 3, imgsz, imgsz) + + model_cpu = copy.deepcopy(model).cpu().eval() + x_cpu = copy.deepcopy(x).cpu() + y_cpu = model_cpu(x_cpu) + + model_npu = model_cpu.to(device).eval() + x_npu = copy.deepcopy(x).to(device) + compiled_model_npu = torch.compile(dynamic=False)(model_npu) + y_npu = compiled_model_npu(x_npu) + + if isinstance(y_cpu, (list, tuple)): + for i, (out_npu, out_cpu) in enumerate(zip(y_npu, y_cpu)): + test_result(f"MobileNet Output {i}", out_npu, out_cpu) + else: + test_result("MobileNet Output", y_npu, y_cpu) + + print("MobileNet Simulation Done") + + +def test_inverted_residual_module(device, batch=1, inp=32, oup=32, stride=1, expand_ratio=6, h=28, w=28): + from torchvision.models.mobilenetv2 import InvertedResidual + + torch.manual_seed(0) + + x = torch.randn(batch, inp, h, w) + + model_cpu = InvertedResidual(inp, oup, stride, expand_ratio).cpu().eval() + x_cpu = copy.deepcopy(x).cpu() + y_cpu = model_cpu(x_cpu) + + model_npu = model_cpu.to(device).eval() + x_npu = copy.deepcopy(x).to(device) + compiled_model_npu = torch.compile(dynamic=False)(model_npu) + y_npu = compiled_model_npu(x_npu) + + test_result("InvertedResidual Module", y_npu, y_cpu) + print("InvertedResidual Module Test Done") + + +if __name__ == "__main__": + base_dir = os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim") + config = os.environ.get( + "TOGSIM_CONFIG", + default=f"{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml", + ) + args = argparse.ArgumentParser() + args.add_argument("--batch", type=int, default=1) + args.add_argument("--dump_path", type=str, default="results") + args = args.parse_args() + batch = args.batch + + device = torch.device("npu:0") + + # print("\n" + "=" * 80) + # print("Testing InvertedResidual Module") + # print("=" * 80) + # test_inverted_residual_module(device, batch=batch, inp=32, oup=32, stride=1, expand_ratio=6, h=28, w=28) + + print("\n" + "=" * 80) + print("Testing Full MobileNet V2 Model") + print("=" * 80) + run_mobilenet(batch, config) diff --git a/tests/Yolov5/test_yolov5.py b/tests/Yolov5/test_yolov5.py new file mode 100644 index 00000000..d98828bd --- /dev/null +++ b/tests/Yolov5/test_yolov5.py @@ -0,0 +1,283 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + +import requests +from PIL import Image +from io import BytesIO +from torchvision import transforms + +import os +import shutil + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def run_yolo(batch, config): + import copy + + device = torch.device("npu:0") + + torch._dynamo.config.recompile_limit = 64 + torch._dynamo.config.cache_size_limit = 128 + + # Load model and prepare input + model = torch.hub.load("ultralytics/yolov5", "yolov5s").cpu().eval() + url = "https://ultralytics.com/images/zidane.jpg" + + response = requests.get(url) + img = Image.open(BytesIO(response.content)).convert("RGB") + + imgsz = 64 + transform = transforms.Compose([ + transforms.Resize((imgsz, imgsz)), + transforms.ToTensor(), + ]) + + x = transform(img).unsqueeze(0) # [1, 3, H, W] + + # CPU version + model_cpu = copy.deepcopy(model).cpu().eval() + x_cpu = copy.deepcopy(x).cpu() + y_cpu = model_cpu(x_cpu) + + # NPU version + model_npu = model_cpu.to(device).eval() + x_npu = copy.deepcopy(x).to(device) + compiled_model_npu = torch.compile(dynamic=False)(model_npu) + y_npu = compiled_model_npu(x_npu) + + # Compare results + # YOLOv5 output is typically a list or tensor, handle both cases + if isinstance(y_cpu, (list, tuple)): + for i, (out_npu, out_cpu) in enumerate(zip(y_npu, y_cpu)): + test_result(f"YOLOv5 Output {i}", out_npu, out_cpu) + else: + test_result("YOLOv5 Output", y_npu, y_cpu) + + print("Yolo Simulation Done") + + +def test_c3_module(device, batch=1, c1=64, c2=128, n=1, h=64, w=64): + import copy + import sys + + # Import C3 module from YOLOv5 + try: + # Load model first to ensure hub cache is populated + _ = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=False) + + # Try to import from torch hub cache + hub_path = os.path.expanduser("~/.cache/torch/hub/ultralytics_yolov5_master") + if os.path.exists(hub_path): + sys.path.insert(0, hub_path) + # Import C3 module + from models.common import C3 # noqa: F401 + except Exception as e: + print(f"Warning: Could not import C3 module: {e}") + print("Skipping C3 module test") + return + + torch.manual_seed(0) + + # Create input tensor + x = torch.randn(batch, c1, h, w) + + # CPU version + model_cpu = C3(c1, c2, n=n, shortcut=True, g=1, e=0.5).cpu().eval() + x_cpu = copy.deepcopy(x).cpu() + y_cpu = model_cpu(x_cpu) + + # NPU version + model_npu = model_cpu.to(device).eval() + x_npu = copy.deepcopy(x).to(device) + compiled_model_npu = torch.compile(dynamic=False)(model_npu) + y_npu = compiled_model_npu(x_npu) + + # Compare results + if isinstance(y_cpu, (list, tuple)): + for i, (out_npu, out_cpu) in enumerate(zip(y_npu, y_cpu)): + test_result(f"C3 Output {i}", out_npu, out_cpu) + else: + test_result("C3 Output", y_npu, y_cpu) + print("C3 Module Test Done") + + +def test_bottleneck_module(device, batch=1, c1=64, c2=64, shortcut=True, g=1, e=0.5, h=16, w=16): + import copy + import sys + + # Import Bottleneck module from YOLOv5 + try: + # Load model first to ensure hub cache is populated + _ = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=False) + + # Try to import from torch hub cache + hub_path = os.path.expanduser("~/.cache/torch/hub/ultralytics_yolov5_master") + if os.path.exists(hub_path): + sys.path.insert(0, hub_path) + # Import Bottleneck module + from models.common import Bottleneck # noqa: F401 + except Exception as e: + print(f"Warning: Could not import Bottleneck module: {e}") + print("Skipping Bottleneck module test") + return + + torch.manual_seed(0) + + # Create input tensor + x = torch.randn(batch, c1, h, w) + + # CPU version + model_cpu = Bottleneck(c1, c2, shortcut=shortcut, g=g, e=e).cpu().eval() + x_cpu = copy.deepcopy(x).cpu() + y_cpu = model_cpu(x_cpu) + + # NPU version + model_npu = model_cpu.to(device).eval() + x_npu = copy.deepcopy(x).to(device) + compiled_model_npu = torch.compile(dynamic=False)(model_npu) + y_npu = compiled_model_npu(x_npu) + + # Compare results + test_result("Bottleneck Module", y_npu, y_cpu) + print("Bottleneck Module Test Done") + + +def test_conv_module(device, batch=1, c1=32, c2=64, k=3, s=1, p=None, g=1, d=1, act=True, h=16, w=16): + import copy + import sys + + # Import Conv module from YOLOv5 + try: + # Load model first to ensure hub cache is populated + _ = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=False) + + # Try to import from torch hub cache + hub_path = os.path.expanduser("~/.cache/torch/hub/ultralytics_yolov5_master") + if os.path.exists(hub_path): + sys.path.insert(0, hub_path) + # Import Conv module + from models.common import Conv # noqa: F401 + except Exception as e: + print(f"Warning: Could not import Conv module: {e}") + print("Skipping Conv module test") + return + + torch.manual_seed(0) + + # Create input tensor + x = torch.randn(batch, c1, h, w) + + # CPU version + model_cpu = Conv(c1, c2, k=k, s=s, p=p, g=g, d=d, act=act).cpu().eval() + x_cpu = copy.deepcopy(x).cpu() + y_cpu = model_cpu(x_cpu) + + # NPU version + model_npu = model_cpu.to(device).eval() + x_npu = copy.deepcopy(x).to(device) + compiled_model_npu = torch.compile(dynamic=False)(model_npu) + y_npu = compiled_model_npu(x_npu) + + # Compare results + test_result("Conv Module", y_npu, y_cpu) + print("Conv Module Test Done") + + +def test_concat_4d(device): + """ + Test concatenating 3 tensors along dimension 4 + Shapes: (1, 3, 4, 4, 2), (1, 3, 4, 4, 2), (1, 3, 4, 4, 81) + Result: (1, 3, 4, 4, 85) + """ + import copy + + torch.manual_seed(0) + + # Create 3 input tensors + x1 = torch.ones(1, 3, 4, 4, 2) + x2 = torch.ones(1, 3, 4, 4, 2) * 2 + x3 = torch.ones(1, 3, 4, 4, 81) * 3 + + # CPU version + x1_cpu = copy.deepcopy(x1).cpu() + x2_cpu = copy.deepcopy(x2).cpu() + x3_cpu = copy.deepcopy(x3).cpu() + y_cpu = torch.cat([x1_cpu, x2_cpu, x3_cpu], dim=4) + + # NPU version + x1_npu = copy.deepcopy(x1).to(device) + x2_npu = copy.deepcopy(x2).to(device) + x3_npu = copy.deepcopy(x3).to(device) + + def concat_fn(x1, x2, x3): + return torch.cat([x1, x2, x3], dim=4) + + compiled_concat = torch.compile(dynamic=False)(concat_fn) + y_npu = compiled_concat(x1_npu, x2_npu, x3_npu) + + # Compare results + test_result("Concat 4D", y_npu, y_cpu) + print(f"Output shape: {y_npu.shape}") + print("Concat 4D Test Done") + +if __name__ == "__main__": + + base_dir = os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim") + config = os.environ.get( + "TOGSIM_CONFIG", + default=f"{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml", + ) + args = argparse.ArgumentParser() + args.add_argument("--batch", type=int, default=1) + args.add_argument("--dump_path", type=str, default="results") + args = args.parse_args() + batch = args.batch + + device = torch.device("npu:0") + + # Test Concat 4D + # print("=" * 80) + # print("Testing Concat 4D") + # print("=" * 80) + # test_concat_4d(device) + + # Test Conv module + # print("\n" + "=" * 80) + # print("Testing Conv Module") + # print("=" * 80) + # test_conv_module(device, batch=batch, c1=32, c2=32, k=1, s=1, p=None, g=1, d=1, act=False, h=16, w=16) + + # Test Bottleneck module + # print("\n" + "=" * 80) + # print("Testing Bottleneck Module") + # print("=" * 80) + # test_bottleneck_module(device, batch=batch, c1=32, c2=32, shortcut=True, g=1, e=0.5, h=16, w=16) + + # Test C3 module + # print("\n" + "=" * 80) + # print("Testing C3 Module") + # print("=" * 80) + # test_c3_module(device, batch=batch, c1=64, c2=64, n=1, h=16, w=16) + + # Test full YOLOv5 model + print("\n" + "=" * 80) + print("Testing Full YOLOv5 Model") + print("=" * 80) + run_yolo(batch, config) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_activation.py b/tests/test_activation.py index 575fc7e8..dacc102e 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -23,9 +23,10 @@ def test_ReLU(device, size=(128, 128)): input = torch.randn(size) x1 = input.to(device=device) x2 = input.to("cpu") - opt_fn = torch.compile(dynamic=False)(torch.nn.functional.relu) + ReLU = torch.nn.ReLU() + opt_fn = torch.compile(dynamic=False)(ReLU) y = opt_fn(x1) - cpu_y = torch.nn.functional.relu(x2) + cpu_y = ReLU(x2) test_result("ReLU", y, cpu_y) def test_GeLU(device, size=(128, 128), approximate='none'): @@ -78,19 +79,14 @@ def test_SwiGLU(device, size=(128, 128)): test_result("SwiGLU", y, cpu_y) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_ReLU(device, (47, 10)) test_ReLU(device, (128, 128)) test_ReLU(device, (4071, 429)) diff --git a/tests/test_add.py b/tests/test_add.py index 118632d5..7a0d23d9 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -48,19 +48,14 @@ def vectoradd(a, b): test_result("VectorTensorAdd", res, out) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_vectoradd(device, (1, 1)) test_vectoradd(device, (47, 10)) test_vectoradd(device, (128, 128)) diff --git a/tests/test_batchnorm.py b/tests/test_batchnorm.py index 251805f5..065c0870 100644 --- a/tests/test_batchnorm.py +++ b/tests/test_batchnorm.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -33,13 +31,7 @@ def test_BatchNorm(device, size=(1, 16, 64, 64)): test_result("BatchNorm Forward", y, cpu_y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_BatchNorm(device) test_BatchNorm(device, size=(1,64, 32, 32)) test_BatchNorm(device, size=(1, 8, 4, 4)) diff --git a/tests/test_bmm.py b/tests/test_bmm.py index d90410db..02a6460e 100644 --- a/tests/test_bmm.py +++ b/tests/test_bmm.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -42,13 +40,7 @@ def bmm(a, b, bias): test_result("BMM Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_BMM(device) test_BMM(device, 2, 256, 128, 256) test_BMM(device, 2, 128, 256, 256) diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..97fcc754 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,200 @@ +import argparse +from pathlib import Path + +import torch + + +def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + return + + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + raise RuntimeError(f"{name} mismatch") + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + out = opt_fn(x, y, out_buf) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim0(device): + def cat_4d_dim0_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(3, 3, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim0_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.4d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim1(device): + def cat_4d_dim1_fn(a, b): + return torch.cat([a, b], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim1_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=1) + _test_result("cat.4d.dim1", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim2(device): + def cat_4d_dim2_fn(a, b): + return torch.cat([a, b], dim=2) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 6, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim2_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=2) + _test_result("cat.4d.dim2", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim3(device): + def cat_4d_dim3_fn(a, b): + return torch.cat([a, b], dim=3) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 4, 7, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim3_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=3) + _test_result("cat.4d.dim3", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_three_inputs(device): + def cat_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=0) + + x = torch.randn(4, 16, device=device) + y = torch.randn(5, 16, device=device) + z = torch.randn(3, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=0) + _test_result("cat.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_four_inputs(device): + def cat_four_inputs_fn(a, b, c, d): + return torch.cat([a, b, c, d], dim=0) + + x = torch.randn(3, 16, device=device) + y = torch.randn(4, 16, device=device) + z = torch.randn(5, 16, device=device) + w = torch.randn(2, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_four_inputs_fn) + + out = opt_fn(x, y, z, w) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu(), w.cpu()], dim=0) + _test_result("cat.four_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_three_inputs(device): + def cat_4d_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 4, 4, 5, device=device) + z = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=1) + _test_result("cat.4d.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + +def test_cat_5d(device, dim=0): + def cat_5d_fn(a, b): + return torch.cat([a, b], dim=dim) + + x = torch.randn(2, 3, 4, 5, 6, device=device) + y = torch.randn(3, 3, 4, 5, 6, device=device) + opt_fn = torch.compile(dynamic=False)(cat_5d_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=dim) + _test_result("cat.5d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run cat simulation tests") + parser.add_argument( + "--case", + choices=[ + "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", "5d" + "three_inputs", "four_inputs", "4d_three_inputs", "all" + ], + default="all", + help="Which cat case to run", + ) + args = parser.parse_args() + + device = torch.device("npu:0") + + if args.case in ("default", "all"): + test_cat_default(device) + if args.case in ("out", "all"): + test_cat_out(device) + if args.case in ("4d_dim0", "all"): + test_cat_4d_dim0(device) + if args.case in ("4d_dim1", "all"): + test_cat_4d_dim1(device) + if args.case in ("4d_dim2", "all"): + test_cat_4d_dim2(device) + if args.case in ("4d_dim3", "all"): + test_cat_4d_dim3(device) + if args.case in ("three_inputs", "all"): + test_cat_three_inputs(device) + if args.case in ("four_inputs", "all"): + test_cat_four_inputs(device) + if args.case in ("4d_three_inputs", "all"): + test_cat_4d_three_inputs(device) + if args.case in ("5d", "all"): + test_cat_5d(device) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 54225747..e6b01bbd 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -49,11 +47,5 @@ def test_CNN(device): print("Max diff > ", torch.max(torch.abs(y.cpu() - cpu_y))) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_CNN(device) diff --git a/tests/test_compile_overhead.py b/tests/test_compile_overhead.py deleted file mode 100644 index 030f548e..00000000 --- a/tests/test_compile_overhead.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -import time -import sys -import torch -from torchvision.models import resnet18 as model1 -import argparse -import shutil - -sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request, poisson_request_generator -CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - -if __name__ == "__main__": - target_model1 = model1().eval() - - # Init scheduler - for i in range(1): - timestamp = time.time() # 현재 타임스탬프 (초 단위) - print(f"[{i}] Time Stamp: {timestamp:.6f}") # 소수점 6자리까지 출력 - #try: - # shutil.rmtree("/tmp/torchinductor") - #except FileNotFoundError: - # print("no cache") - scheduler = Scheduler(num_request_queue=1, max_batch=4, engine_select=Scheduler.FIFO_ENGINE, togsim_config=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json") - # Register compiled model - opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last), dynamic=False) - SchedulerDNNModel.register_model("resnet18", opt_model1) - - # Generate time stamp - for request_time in [0]*12: - # Init input data - model_input1 = torch.randn(1, 3, 224, 224) - - # Init request - new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0) - - # Add request to scheduler - print("[Reqest] Resnet18 request time: ", request_time, flush=True) - scheduler.add_request(new_request1, request_time=request_time) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() - - print("Done", file=sys.stderr) \ No newline at end of file diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index e964319d..313003b1 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -1,6 +1,5 @@ import torch import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -36,23 +35,19 @@ def custom_conv2d(a, b, bias): print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") torch._dynamo.config.cache_size_limit = 64 - test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) - test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) - test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0) + with torch.no_grad(): + test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) + test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0) + test_conv2d(device, batch_size=1, in_channels=8, out_channels=16, input_size=1, kernel_size=1,stride=1, padding=0) diff --git a/tests/test_eager.py b/tests/test_eager.py new file mode 100644 index 00000000..b84cc6f6 --- /dev/null +++ b/tests/test_eager.py @@ -0,0 +1,12 @@ +import torch + +torch.npu.register_eager_to_compile(["aten::mul.Tensor", "aten::add.Tensor"]) + +if __name__ == "__main__": + #torch.npu.register_fallback_op("aten::add.out", my_fallback) + device = torch.device("npu:0") + x = torch.ones(10, 10).to(device) + y = torch.ones(10, 10).to(device) + z = x * y + z = x + z + print(z.cpu()) \ No newline at end of file diff --git a/tests/test_exponent.py b/tests/test_exponent.py index e60f8407..20f0a143 100644 --- a/tests/test_exponent.py +++ b/tests/test_exponent.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -27,11 +25,5 @@ def exponent(a): test_result("exponent", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_exponent(device, size=(32, 32)) diff --git a/tests/test_gqa.py b/tests/test_gqa.py new file mode 100644 index 00000000..ba262fa6 --- /dev/null +++ b/tests/test_gqa.py @@ -0,0 +1,333 @@ +import sys +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch._dynamo +import argparse + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + + +class GQAMultiheadAttention(nn.Module): + """ + Grouped Query Attention (GQA) implementation. + Query has num_heads, but key/value have num_kv_heads (num_kv_heads < num_heads). + """ + def __init__(self, embed_dim, num_heads, num_kv_heads=None, head_dim=None, bias=True, dropout=0.0): + super().__init__() + assert embed_dim % num_heads == 0 + if head_dim is None: + head_dim = embed_dim // num_heads + assert embed_dim == num_heads * head_dim + + # If num_kv_heads is not specified, use num_heads (standard MHA) + if num_kv_heads is None: + num_kv_heads = num_heads + + assert num_kv_heads <= num_heads + assert embed_dim % num_kv_heads == 0 + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.dropout = dropout + + # QKV projection: Q has embed_dim, K and V have kv_embed_dim each + kv_embed_dim = num_kv_heads * head_dim + total_qkv_dim = embed_dim + 2 * kv_embed_dim + + self.qkv_proj = nn.Linear(embed_dim, total_qkv_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward(self, query, key=None, value=None, attn_mask=None, need_weights=False): + """ + Args: + query: [batch, seq_len, embed_dim] or [seq_len, batch, embed_dim] + key: optional, same shape as query + value: optional, same shape as query + attn_mask: optional attention mask + need_weights: whether to return attention weights + """ + # For compatibility with nn.MultiheadAttention API + if key is None: + key = query + if value is None: + value = query + + # Handle batch_first vs batch_second + if query.dim() == 3: + batch_first = True + batch_size, seq_len, _ = query.shape + else: + batch_first = False + seq_len, batch_size, _ = query.shape + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + # Project QKV + # Use query for QKV projection (standard MHA/GQA pattern) + qkv = self.qkv_proj(query) # [batch, seq_len, total_qkv_dim] + + # Split into Q, K, V + kv_embed_dim = self.num_kv_heads * self.head_dim + q = qkv[:, :, :self.embed_dim] # [batch, seq_len, embed_dim] + k = qkv[:, :, self.embed_dim:self.embed_dim + kv_embed_dim] # [batch, seq_len, kv_embed_dim] + v = qkv[:, :, self.embed_dim + kv_embed_dim:] # [batch, seq_len, kv_embed_dim] + + # Reshape to multi-head format + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) # [batch, seq_len, num_heads, head_dim] + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) # [batch, seq_len, num_kv_heads, head_dim] + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) # [batch, seq_len, num_kv_heads, head_dim] + + # Transpose for attention: [batch, num_heads, seq_len, head_dim] + q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] + v = v.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] + + # Scaled dot product attention with GQA support + # enable_gqa=True allows different number of heads for Q vs K/V + attn_output = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=False, + enable_gqa=(self.num_kv_heads < self.num_heads) + ) # [batch, num_heads, seq_len, head_dim] + + # Reshape back: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2) # [batch, seq_len, num_heads, head_dim] + attn_output = attn_output.contiguous().view(batch_size, seq_len, self.embed_dim) + + # Output projection + output = self.out_proj(attn_output) # [batch, seq_len, embed_dim] + + if not batch_first: + output = output.transpose(0, 1) # [seq_len, batch, embed_dim] + + if need_weights: + # Compute attention weights for return + # This is simplified - in practice you'd want the actual attention weights + attn_weights = None + return output, attn_weights + else: + return output + + +def test_gqa_attention(device, batch=1, seq_len=32, embed_dim=768, num_heads=12, num_kv_heads=4): + """ + Test Grouped Query Attention (GQA) where num_kv_heads < num_heads. + + Args: + device: target device + batch: batch size + seq_len: sequence length + embed_dim: embedding dimension + num_heads: number of query heads + num_kv_heads: number of key/value heads (should be <= num_heads) + """ + print(f"Testing GQA Attention (batch={batch}, seq_len={seq_len}, embed_dim={embed_dim}, " + f"num_heads={num_heads}, num_kv_heads={num_kv_heads})") + + # Create GQA model + gqa = GQAMultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + bias=True, + dropout=0.0 + ).eval() + + # Initialize weights + torch.nn.init.normal_(gqa.qkv_proj.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.qkv_proj.bias, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.out_proj.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.out_proj.bias, mean=0.0, std=0.02) + + # Create input + x = torch.randn(batch, seq_len, embed_dim) + query = x.clone() + key = x.clone() + value = x.clone() + + # Run on custom device + gqa_device = gqa.to(device) + q1, k1, v1 = query.to(device), key.to(device), value.to(device) + + compiled_gqa = torch.compile(gqa_device, dynamic=False) + with torch.no_grad(): + out_device = compiled_gqa(q1, k1, v1) + + # Run on CPU + gqa_cpu = gqa.cpu() + q2, k2, v2 = query.cpu(), key.cpu(), value.cpu() + with torch.no_grad(): + out_cpu = gqa_cpu(q2, k2, v2) + + test_result("GQA Attention", out_device, out_cpu) + print("Max diff > ", torch.max(torch.abs(out_device.cpu() - out_cpu))) + print("GQA Attention Simulation Done") + + +def test_standard_mha_via_gqa(device, batch=1, seq_len=32, embed_dim=768, num_heads=12): + """ + Test standard Multi-Head Attention using GQA with num_kv_heads == num_heads. + This should behave the same as standard MHA. + """ + print(f"Testing Standard MHA via GQA (batch={batch}, seq_len={seq_len}, " + f"embed_dim={embed_dim}, num_heads={num_heads})") + + test_gqa_attention(device, batch, seq_len, embed_dim, num_heads, num_kv_heads=num_heads) + + +def test_repeat_interleave_compilation(device, batch=1, seq_len=32, embed_dim=768, num_heads=12, num_kv_heads=4): + """ + Test that repeat_interleave operation compiles and works correctly using scaled_dot_product_attention implementation. + + This test uses the exact implementation from F.scaled_dot_product_attention to verify + that repeat_interleave works correctly when enable_gqa=True. + + Args: + device: target device + batch: batch size + seq_len: sequence length + embed_dim: embedding dimension + num_heads: number of query heads + num_kv_heads: number of key/value heads (should be < num_heads) + """ + import math + + print(f"Testing repeat_interleave compilation using scaled_dot_product_attention implementation " + f"(batch={batch}, seq_len={seq_len}, embed_dim={embed_dim}, " + f"num_heads={num_heads}, num_kv_heads={num_kv_heads})") + + head_dim = embed_dim // num_heads + assert num_kv_heads < num_heads, "num_kv_heads must be less than num_heads for GQA" + + # Create Q, K, V tensors + # Q: [batch, num_heads, seq_len, head_dim] + # K, V: [batch, num_kv_heads, seq_len, head_dim] + q = torch.randn(batch, num_heads, seq_len, head_dim) + k = torch.randn(batch, num_kv_heads, seq_len, head_dim) + v = torch.randn(batch, num_kv_heads, seq_len, head_dim) + + # Move to device + q_device = q.to(device) + k_device = k.to(device) + v_device = v.to(device) + + # Implementation from F.scaled_dot_product_attention + def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) + value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight, value, attn_weight @ value + + # Compile the function + compiled_attn = torch.compile(scaled_dot_product_attention, dynamic=False) + + # Run on custom device with enable_gqa=True + with torch.no_grad(): + output_device = compiled_attn(q_device, k_device, v_device, + attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=True) + + # Run on CPU for comparison + q_cpu = q.cpu() + k_cpu = k.cpu() + v_cpu = v.cpu() + with torch.no_grad(): + output_cpu = scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, + attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=True) + + # Compare results + test_result("repeat_interleave in scaled_dot_product_attention", output_device[0], output_cpu[0]) + print("Max diff > ", torch.max(torch.abs(output_device[0].cpu() - output_cpu[0]))) + test_result("repeat_interleave in scaled_dot_product_attention", output_device[1], output_cpu[1]) + print("Max diff > ", torch.max(torch.abs(output_device[1].cpu() - output_cpu[1]))) + test_result("repeat_interleave in scaled_dot_product_attention", output_device[2], output_cpu[2]) + print("Max diff > ", torch.max(torch.abs(output_device[2].cpu() - output_cpu[2]))) + print("repeat_interleave compilation test Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="npu", help="Device to use") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--seq_len", type=int, default=32, help="Sequence length") + parser.add_argument("--embed_dim", type=int, default=768, help="Embedding dimension") + parser.add_argument("--num_heads", type=int, default=8, help="Number of query heads") + parser.add_argument("--num_kv_heads", type=int, default=4, help="Number of key/value heads") + parser.add_argument("--test_standard", action="store_true", help="Also test standard MHA via GQA") + parser.add_argument("--test_repeat_interleave", action="store_true", help="Test repeat_interleave compilation") + + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + device = torch.device("npu:0") + + test_repeat_interleave_compilation( + device=device, + batch=args.batch, + seq_len=args.seq_len, + embed_dim=args.embed_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads + ) + + # Test GQA + test_gqa_attention( + device=device, + batch=args.batch, + seq_len=args.seq_len, + embed_dim=args.embed_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads + ) + + # Optionally test standard MHA via GQA + # if args.test_standard: + # test_standard_mha_via_gqa( + # device=args.device, + # batch=args.batch, + # seq_len=args.seq_len, + # embed_dim=args.embed_dim, + # num_heads=args.num_heads + # ) diff --git a/tests/test_gqa_decode.py b/tests/test_gqa_decode.py new file mode 100644 index 00000000..7a7ab06c --- /dev/null +++ b/tests/test_gqa_decode.py @@ -0,0 +1,215 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import sys +import math +import argparse +from Simulator.simulator import TOGSimulator +device = torch.device("npu:0") +# ───────────────────────────────────────────────────────────────────────────── +# Optimized: Flash-Decode style — tile S upfront, batch in B dimension +# ───────────────────────────────────────────────────────────────────────────── + +class GQADecodeOptimized(nn.Module): + """Flash-Decode style GQA decode for multi-core NPU. + + Splits the KV-cache sequence into n_tiles chunks and folds them into the + BMM batch dimension (B_total = H_kv × n_tiles). Both the QK and SV + matrix multiplications are issued as a *single* batched BMM with a short + inner-K loop, so the NPU scheduler can distribute all B_total tiles across + available cores simultaneously. + + Improvement over GQABaseline + ───────────────────────────── + Baseline QK : B=H_kv=1, M=G, N=S(large), K=D → 640 N-tile iters on 1 batch + Optimized QK: B=H_kv*n_tiles, M=G, N=T(small), K=D → n_tiles batch slots for cores + + Baseline SV : B=H_kv=1, M=G, N=D, K=S → K-loop=640, only 8 outer tiles + Optimized SV: B=H_kv*n_tiles, M=G, N=D, K=T → K-loop=T/TILE_K, n_tiles outer tiles + + Memory layout improvements + ────────────────────────── + • K/V tiles are generated with a single contiguous view+reshape (no mid-loop transpose). + • Avoids materializing the full score tensor [H_kv, G, S] in DRAM before tiling. + • Softmax intermediates are kept in smaller [B_total, G, T] buffers. + + Input conventions + ───────────────── + q : [H_kv, G, D] – one decode-step query token per KV head + k : [H_kv, S, D] – KV-cache keys (NOT pre-transposed) + v : [H_kv, S, D] – KV-cache values + + tile_size selection + ─────────────────── + Ideal: tile_size = round_up(S * H_kv / num_cores, vpu_num_lanes) + so that B_total ≈ num_cores. Must also satisfy the SPAD budget: + (G*T + T*D + G*D) * bytes ≤ spad_per_core (for sub-tile occupancy) + Default 512 works for (G=5, D=128, fp16, 16-lane × 8 KB/lane SPAD). + """ + + def __init__(self, tile_size: int = 512): + super().__init__() + self.tile_size = tile_size + + def forward( + self, + q: torch.Tensor, # [H_kv, G, D] + k: torch.Tensor, # [H_kv, S, D] + v: torch.Tensor, # [H_kv, S, D] + scale: float, + ) -> torch.Tensor: + H_kv, G, D = q.shape + _, S, _ = k.shape + T = self.tile_size + n_tiles = (S + T - 1) // T + pad_len = n_tiles * T - S + B_total = H_kv * n_tiles + + # ── 1. Pad S → multiple of T ─────────────────────────────────────── + if pad_len > 0: + k = F.pad(k, (0, 0, 0, pad_len)) # [H_kv, S', D] + v = F.pad(v, (0, 0, 0, pad_len)) # [H_kv, S', D] + + # ── 2. Tile K, V → [B_total, T, D] (contiguous, no copy) ───────── + # k is [H_kv, S', D]; view splits S' → n_tiles×T along dim-1 + k_tiles = k.view(H_kv, n_tiles, T, D).reshape(B_total, T, D) + v_tiles = v.view(H_kv, n_tiles, T, D).reshape(B_total, T, D) + + # ── 3. Expand Q → [B_total, G, D] ───────────────────────────────── + # expand: zero-copy view; reshape: contiguous copy (small: B_total*G*D elems) + q_exp = q.unsqueeze(1).expand(H_kv, n_tiles, G, D).reshape(B_total, G, D) + + # ── 4. Batched QK BMM ────────────────────────────────────────────── + # [B_total, G, D] × [B_total, D, T] → [B_total, G, T] + # NPU mapping: B=B_total, M=G, N=T, K=D + # → outer tiles = B_total × M_tiles × N_tiles (all parallelizable) + # → inner K-loop = D/TILE_K (short, D=128) + k_t = k_tiles.transpose(1, 2) # [B_total, D, T] + scores = torch.bmm(q_exp, k_t) * scale # [B_total, G, T] + + # ── 5. Tile-local softmax (fp32 accumulation) ────────────────────── + # All ops are elementwise on [B_total, G, T] → torch.compile fuses them + scores_f32 = scores.float() + local_max = scores_f32.amax(dim=-1, keepdim=True) # [B_total, G, 1] + local_exp = (scores_f32 - local_max).exp() # [B_total, G, T] + local_sum = local_exp.sum(dim=-1, keepdim=True) # [B_total, G, 1] + + # ── 6. Batched SV BMM ────────────────────────────────────────────── + # [B_total, G, T] × [B_total, T, D] → [B_total, G, D] + # NPU mapping: B=B_total, M=G, N=D, K=T + # → outer tiles = B_total × M_tiles × N_tiles (parallelizable) + # → inner K-loop = T/TILE_K (controlled, T≪S) + sv = torch.bmm(local_exp.to(q.dtype), v_tiles) # [B_total, G, D] + + # ── 7. Online-softmax global reduction (elementwise, fused) ──────── + local_max = local_max.view(H_kv, n_tiles, G, 1) + local_sum = local_sum.view(H_kv, n_tiles, G, 1) + sv = sv.view(H_kv, n_tiles, G, D) + + global_max = local_max.amax(dim=1, keepdim=True) # [H_kv, 1, G, 1] + rescale = (local_max - global_max).exp() # [H_kv, n_tiles, G, 1] + corrected_sv = (sv * rescale).sum(dim=1) # [H_kv, G, D] + corrected_sum = (local_sum * rescale).sum(dim=1) # [H_kv, G, 1] + + return (corrected_sv / corrected_sum.clamp_min(1e-12)).to(q.dtype) + + +# ───────────────────────────────────────────────────────────────────────────── +# Test +# ───────────────────────────────────────────────────────────────────────────── + +MODEL_CONFIGS = { + "LLAMA4_TP8": { + "HEAD_DIM": 128, + "NUM_HEADS": 5, # = 40 total / TP8 + "NUM_KV_HEADS": 1, # = 8 total / TP8 + }, + "QWEN3-235B_TP4": { + "HEAD_DIM": 128, + "NUM_HEADS": 16, + "NUM_KV_HEADS": 1, + }, + "GPT-OSS_TP1": { + "HEAD_DIM": 64, + "NUM_HEADS": 64, + "NUM_KV_HEADS": 8, + }, + "GPT-OSS_TP2": { + "HEAD_DIM": 64, + "NUM_HEADS": 32, + "NUM_KV_HEADS": 4, + }, + "GPT-OSS_TP4": { + "HEAD_DIM": 64, + "NUM_HEADS": 16, + "NUM_KV_HEADS": 2, + }, + "GPT-OSS_TP8": { + "HEAD_DIM": 64, + "NUM_HEADS": 8, + "NUM_KV_HEADS": 1, + }, +} + + +def _make_inputs(cfg, seq_len, dtype): + H_kv = cfg["NUM_KV_HEADS"] + G = cfg["NUM_HEADS"] // cfg["NUM_KV_HEADS"] + D = cfg["HEAD_DIM"] + scale = 1.0 / math.sqrt(D) + + q = torch.randn(H_kv, G, D, dtype=dtype) + k = torch.randn(H_kv, seq_len, D, dtype=dtype) # NOT pre-transposed + v = torch.randn(H_kv, seq_len, D, dtype=dtype) + return q, k, v, scale + + +def test_gqa_decode_optimized(model, device, seq_len: int = 10240, tile_size: int = 512): + + cfg = MODEL_CONFIGS[model] if model is not None else MODEL_CONFIGS["LLAMA4_TP8"] + dtype = torch.float16 + + model = GQADecodeOptimized(tile_size=tile_size).eval() + + # ── NPU run ──────────────────────────────────────────────────────────── + q, k, v, scale = _make_inputs(cfg, seq_len, dtype) + model_dev = model.to(device) + compiled = torch.compile(model_dev, dynamic=False) + + q_dev, k_dev, v_dev = q.to(device), k.to(device), v.to(device) + with torch.no_grad(): + with TOGSimulator(): + out_dev = compiled(q_dev, k_dev, v_dev, scale=scale) + + # ── CPU reference ────────────────────────────────────────────────────── + with torch.no_grad(): + out_cpu = model.cpu()(q, k, v, scale=scale) + + max_diff = (out_dev.cpu() - out_cpu).abs().max().item() + + with torch.no_grad():#CPU reference + out_library = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, enable_gqa=True) + + max_diff_library = (out_library.cpu() - out_cpu).abs().max().item() + + print(f"[GQADecodeOptimized] seq_len={seq_len}, tile_size={tile_size}") + print(f" max |npu - cpu| = {max_diff:.6f}") + print(f" npu out max = {out_dev.cpu().abs().max().item():.6f}") + print(f" cpu out max = {out_cpu.abs().max().item():.6f}") + print(f" library out max = {out_library.abs().max().item():.6f}") + print(" PASS" if max_diff < 0.05 else " FAIL (diff too large)") + + + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser(description="Test GQA Attention Implementations") + argparser.add_argument("--model", type=str, default="LLAMA4_TP8", choices=MODEL_CONFIGS.keys(), help="Model configuration to test") + argparser.add_argument("--context_length", type=int, default=10240, help="Sequence length (context length) for the attention test") + argparser.add_argument("--tile_size", type=int, default=4096, help="Tile size for the optimized attention implementation") + args = argparser.parse_args() + model = args.model + base_dir = os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim") + sys.path.append(base_dir) + test_gqa_decode_optimized(model=model, device=device, seq_len=args.context_length, tile_size=args.tile_size) diff --git a/tests/test_group_conv.py b/tests/test_group_conv.py new file mode 100644 index 00000000..4f97cff6 --- /dev/null +++ b/tests/test_group_conv.py @@ -0,0 +1,79 @@ +import torch +import torch._dynamo +from Simulator.simulator import TOGSimulator + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_group_convolution( + device, + groups=2, + stride=1, + padding=1, + batch_size=2, + c_per_group=8, + out_per_group=12, + spatial=16, + kernel_size=3, + seed=0, +): + """``torch.compile`` on NPU vs CPU reference — same structure as ``test_matmul`` / ``test_conv2d``.""" + + def custom_group_conv(a, weight, bias): + return torch.convolution( + a, + weight, + bias, + (stride, stride), + (padding, padding), + (1, 1), + False, + (0, 0), + groups, + ) + + torch.manual_seed(seed) + c_in = c_per_group * groups + c_out = out_per_group * groups + k = kernel_size + x = torch.randn(batch_size, c_in, spatial, spatial) + wgt = torch.randn(c_out, c_in // groups, k, k) + b = torch.randn(c_out) + + x1 = x.to(device=device, memory_format=torch.channels_last) + w1 = wgt.to(device=device, memory_format=torch.channels_last) + b1 = b.to(device=device) + x2 = x.to("cpu", memory_format=torch.channels_last) + w2 = wgt.to("cpu", memory_format=torch.channels_last) + b2 = b.to("cpu") + + opt_fn = torch.compile(dynamic=False)(custom_group_conv) + res = opt_fn(x1, w1, b1) + y = custom_group_conv(x2, w2, b2) + label = f"Group Conv Forward (groups={groups}, stride={stride}, pad={padding})" + test_result(label, res, y, rtol=1e-3, atol=1e-3) + print("Max diff > ", torch.max(torch.abs(res.cpu() - y))) + + +if __name__ == "__main__": + device = torch.device("npu:0") + with torch.no_grad(): + #test_group_convolution(device, batch_size=1, groups=2, stride=1, padding=1, seed=0) + #test_group_convolution(device, batch_size=1, groups=4, stride=1, padding=1, seed=1) + #test_group_convolution(device, batch_size=1, groups=2, stride=2, padding=1, seed=2) + test_group_convolution(device, batch_size=1, groups=240, stride=2, padding=1, seed=2, c_per_group=1, out_per_group=1, spatial=40) + + #test_group_convolution(device, batch_size=1, groups=240, stride=2, padding=1, seed=2, c_per_group=1, out_per_group=1) + print("test_group_conv_decomposition: all passed") diff --git a/tests/test_hetro.py b/tests/test_hetro.py index a0716e2d..eaf145d4 100644 --- a/tests/test_hetro.py +++ b/tests/test_hetro.py @@ -2,28 +2,31 @@ import sys import torch import argparse -sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + +sys.path.append(os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim")) + +from Simulator.simulator import TOGSimulator from test_stonne import sparse_matmul + def custom_matmul(a, b): return torch.matmul(a, b) + + torch.manual_seed(0) -CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +CONFIG_TORCHSIM_DIR = os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim") if __name__ == "__main__": parser = argparse.ArgumentParser(description="") parser.add_argument("--M", type=int, default=128, help="Batch size") parser.add_argument("--N", type=int, default=128, help="Input layer size") parser.add_argument("--K", type=int, default=128, help="Hidden layer size") - parser.add_argument("--sparsity", type=float, default=0.9, help="Output layer size") - parser.add_argument("--config", type=str, default="stonne_big_c1_simple_noc.json", help="Output layer size") - parser.add_argument("--mode", type=int, default=0, help="Output layer size") + parser.add_argument("--sparsity", type=float, default=0.9, help="Sparsity") + parser.add_argument("--config", type=str, default="stonne_big_c1_simple_noc.yml", help="TOGSim config file name under configs/") + parser.add_argument("--mode", type=int, default=0, help="0=spmm only, 1=dense matmul only, 2=both partitions") args = parser.parse_args() - M = args.M - N = args.N - K = args.K + M, N, K = args.M, args.N, args.K sparsity = args.sparsity mode = args.mode config_path = f"{CONFIG_TORCHSIM_DIR}/configs/{args.config}" @@ -33,45 +36,30 @@ def custom_matmul(a, b): print("K: ", K) print("sparsity: ", sparsity) - with torch.no_grad(): - # Init scheduler - scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, - togsim_config=config_path) - - # Register compiled model - opt_model1 = torch.compile(custom_matmul) - opt_model2 = torch.compile(sparse_matmul) - SchedulerDNNModel.register_model("matmul", opt_model1) - SchedulerDNNModel.register_model("spmm", opt_model2) + device = torch.device("npu:0") - # Init input data - for i in range(1): - dense_input1 = torch.randn(M, K) - dense_input2 = torch.randn(K, N) + opt_model1 = torch.compile(custom_matmul) + opt_model2 = torch.compile(sparse_matmul) - sparse_input1 = torch.randn(128, 128) - sparse_input2 = torch.randn(128, 128) - mask1 = torch.rand(sparse_input1.shape) > sparsity - mask2 = torch.rand(sparse_input2.shape) > sparsity + dense_input1 = torch.randn(M, K, device=device) + dense_input2 = torch.randn(K, N, device=device) - sparse_input1 = sparse_input1 * mask1 - sparse_input2 = sparse_input2 * mask2 + sparse_input1 = torch.randn(128, 128, device=device) + sparse_input2 = torch.randn(128, 128, device=device) + mask1 = torch.rand(sparse_input1.shape, device=device) > sparsity + mask2 = torch.rand(sparse_input2.shape, device=device) > sparsity + sparse_input1 = sparse_input1 * mask1 + sparse_input2 = sparse_input2 * mask2 - # Init request + with torch.no_grad(): + with TOGSimulator(config_path=config_path): if mode == 0: - new_request1 = Request("spmm", [sparse_input1, sparse_input2], [], request_queue_idx=0) - scheduler.add_request(new_request1, request_time=0) + torch.npu.launch_model(opt_model2, sparse_input1, sparse_input2, stream_index=0, timestamp=0) elif mode == 1: - new_request2 = Request("matmul", [dense_input1, dense_input2], [], request_queue_idx=0) - scheduler.add_request(new_request2, request_time=0) + torch.npu.launch_model(opt_model1, dense_input1, dense_input2, stream_index=0, timestamp=0) elif mode == 2: - new_request1 = Request("spmm", [sparse_input1, sparse_input2], [], request_queue_idx=0) - new_request2 = Request("matmul", [dense_input1, dense_input2], [], request_queue_idx=1) - - # Add request to scheduler - scheduler.add_request(new_request1, request_time=0) - scheduler.add_request(new_request2, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() \ No newline at end of file + torch.npu.launch_model(opt_model2, sparse_input1, sparse_input2, stream_index=0, timestamp=0) + torch.npu.launch_model(opt_model1, dense_input1, dense_input2, stream_index=1, timestamp=0) + else: + raise ValueError(f"unknown mode {mode}") + torch.npu.synchronize() diff --git a/tests/test_indirect_access.py b/tests/test_indirect_access.py index c6afaf86..95167d1e 100644 --- a/tests/test_indirect_access.py +++ b/tests/test_indirect_access.py @@ -1,7 +1,5 @@ import torch import copy -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -43,13 +41,45 @@ def test_embedding(device, vocab_size, dim): cpu_res = cpu_emb(cpu_prompt) test_result("Embedding", res, cpu_res) -if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +def test_scatter_add(device, num_tokens=256, hidden_size=256, num_assignments=3, dtype=torch.float32, seed=0): + torch.manual_seed(seed) + + def scatter_only(out, token_indices, weighted_output): + # token_indices: [N] (long), weighted_output: [N, H] + out.index_add_(0, token_indices, weighted_output) + return out + + out = torch.randn(num_tokens, hidden_size, dtype=dtype) + out_cp = out.clone() + token_indices = torch.randint(0, num_tokens, (num_assignments,)) + weighted_output = torch.randn(num_assignments, hidden_size, dtype=dtype) + + cpu_out = scatter_only(out, token_indices, weighted_output) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + out = out_cp.to(device=device) + token_indices = token_indices.to(device=device) + weighted_output = weighted_output.to(device=device) + opt_fn = torch.compile(dynamic=False)(scatter_only) + res = opt_fn(out, token_indices, weighted_output) + test_result("ScatterAdd(index_add_)", res, cpu_out) + +def test_scatter_full(device, size=(128, 128)): + def vectoradd(a, idx, b): + a[idx, :] = b + return a + x = torch.randn(size, dtype=torch.float32).to(device=device) + x_cpu = x.clone().cpu() + idx = torch.randint(0,128, [128]).to(device=device) + y = torch.randn(size[1], dtype=torch.float32).to(device=device) + opt_fn = torch.compile(dynamic=False)(vectoradd) + res = opt_fn(x, idx, y) + out = vectoradd(x_cpu, idx.cpu(), y.cpu()) + test_result("Indirect VectorAdd", res, out) + +if __name__ == "__main__": + device = torch.device("npu:0") + test_scatter_full(device) + test_scatter_full(device, size=(2048, 2048)) + test_scatter_add(device) test_indirect_vectoradd(device) #test_embedding(device, 1024, 2048) \ No newline at end of file diff --git a/tests/test_layernorm.py b/tests/test_layernorm.py index 28e38d37..3db27dc5 100644 --- a/tests/test_layernorm.py +++ b/tests/test_layernorm.py @@ -31,18 +31,14 @@ def test_LayerNorm(device, size=(64, 64)): test_result("LayerNorm Forward", y, cpu_y) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, help="Shape of the tensor in the format (batch_size, features)", default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() - #test_LayerNorm(device) - test_LayerNorm(device, shape) + device = torch.device("npu:0") + with torch.no_grad(): + #test_LayerNorm(device) + test_LayerNorm(device, shape) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index cd30bd30..a5bdf422 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -90,13 +88,7 @@ def custom_linear(a, b, bias): test_result("Linear Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul(device, 32, 32, 32) test_matmul(device, 128, 128, 128) test_matmul(device, 256, 256, 256) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 423d6e8e..e3f79561 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -105,13 +103,7 @@ def test_optimizer(device): test_result("Optimizer", model.linear1.weight, cpu_model.linear1.weight) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_mlp(device) test_mlp_inf(device, batch_size=1, input_size=256, hidden_size=512, output_size=256) test_mlp_inf(device, batch_size=8, input_size=256, hidden_size=512, output_size=256) diff --git a/tests/test_pool.py b/tests/test_pool.py index f5505dba..2848e04b 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -43,13 +41,7 @@ def avgpool(a): test_result("Avgpool Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_maxpool(device, b=1, c=8, h=16, w=16) #test_maxpool(device, b=1, c=8, h=112, w=112) test_avgpool(device, b=1, c=512, h=7, w=7) diff --git a/tests/test_reduce.py b/tests/test_reduce.py index 4781112d..07f8fef2 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -37,19 +37,14 @@ def reduce_sum(a, dim, keepdim): test_result("ReduceMax", res, out) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(128,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_reduce_sum(device, (29, 47), 1, keepdim=True) test_reduce_sum(device, (17, 68), 0, keepdim=True) test_reduce_sum(device, (327, 447), 1, keepdim=True) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index c83f13ba..2459cd58 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -49,7 +49,5 @@ def test_resnet(device, batch=1, model_type='resnet18'): args = args.parse_args() sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_resnet(device, model_type=args.model_type) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 4860de56..beab8054 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,43 +1,25 @@ import os -import sys import torch from torchvision.models import resnet18 as model1 from test_transformer import EncoderBlock as model2 +from Simulator.simulator import TOGSimulator base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') -sys.path.append(base_path) -from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request -config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json' +config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml' target_model1 = model1().eval() target_model2 = model2(768, 12).eval() -# Init scheduler -scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) -# Register compiled model -opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last)) -opt_model2 = torch.compile(target_model2.to(device=scheduler.execution_engine.module.custom_device())) -SchedulerDNNModel.register_model("resnet18", opt_model1) -SchedulerDNNModel.register_model("bert", opt_model2) - -# Init input data -model_input1 = torch.randn(1, 3, 224, 224) -model_input2 = torch.randn(128, 768) - -# Init request -new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0) -new_request2 = Request("bert", [model_input2], [], request_queue_idx=1) -new_request3 = Request("resnet18", [model_input1], [], request_queue_idx=0) -new_request4 = Request("bert", [model_input2], [], request_queue_idx=1) - -# Add request to scheduler -scheduler.add_request(new_request1, request_time=0) -scheduler.add_request(new_request2, request_time=0) -scheduler.add_request(new_request3, request_time=0) -scheduler.add_request(new_request4, request_time=0) - -# Run scheduler -while not scheduler.is_finished(): - scheduler.schedule() - -print("Done") \ No newline at end of file +device = torch.device("npu:0") +opt_model1 = torch.compile(target_model1.to(device=device, memory_format=torch.channels_last)) +opt_model2 = torch.compile(target_model2.to(device=device)) +model_input1 = torch.randn(1, 3, 224, 224).to(device=device) +model_input2 = torch.randn(128, 768).to(device=device) + +with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_model1, model_input1, stream_index=0, timestamp=0) + torch.npu.launch_model(opt_model2, model_input2, stream_index=1, timestamp=0) + torch.npu.synchronize() + torch.npu.launch_model(opt_model1, model_input1, stream_index=0, timestamp=0) + torch.npu.launch_model(opt_model2, model_input2, stream_index=1, timestamp=0) +print("Done") diff --git a/tests/test_scheduler_batching.py b/tests/test_scheduler_batching.py deleted file mode 100644 index 53f9256d..00000000 --- a/tests/test_scheduler_batching.py +++ /dev/null @@ -1,41 +0,0 @@ -import os -import sys -import torch -from torchvision.models import resnet18 as model1 -import argparse - -sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request, poisson_request_generator -CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Poisson Request Generator (ms)") - parser.add_argument("lambda_requests", nargs="?", type=int, help="Average requests per second (λ)", default=2000) - parser.add_argument("max_time", nargs="?", type=int, help="Maximum simulation time in milliseconds", default=30) - - args = parser.parse_args() - target_model1 = model1().eval() - - # Init scheduler - scheduler = Scheduler(num_request_queue=1, max_batch=32, engine_select=Scheduler.FIFO_ENGINE, togsim_config=f"{CONFIG_TORCHSIM_DIR}/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json") - # Register compiled model - opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last), dynamic=False) - SchedulerDNNModel.register_model("resnet18", opt_model1) - - # Generate time stamp - for request_time in poisson_request_generator(args.lambda_requests, args.max_time): - # Init input data - model_input1 = torch.randn(1, 3, 224, 224) - - # Init request - new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0) - - # Add request to scheduler - print("[Reqest] Resnet18 request time: ", request_time, flush=True) - scheduler.add_request(new_request1, request_time=request_time) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() - - print("Done", file=sys.stderr) \ No newline at end of file diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py new file mode 100644 index 00000000..c4825731 --- /dev/null +++ b/tests/test_sdpa.py @@ -0,0 +1,145 @@ +import sys +import os +import torch +import torch._dynamo +import torch.nn.functional as F + +base_dir = os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim") +sys.path.append(base_dir) + +device = torch.device("npu:0") + +# --------------------------------------------------------------------------- +# Default sweep configs - edit here to change what gets tested +# --------------------------------------------------------------------------- +SDPA_DEFAULTS = dict( + n_batch_list = [1, 4, 8, 16], + n_head_list = [4, 6, 8, 12], + n_token_list = [128, 256, 512, 1024], + head_dim_list = [32, 64, 128], + is_causal = False, +) + +GQA_DEFAULTS = dict( + batch_list = [1], + num_kv_heads = 1, + gqa_ratios = [4, 5, 8, 16], # Hq = ratio * num_kv_heads + seq_len_list = [128, 256, 1024], + head_dim_list = [64, 128], + query_len = 1, # decode shape: Lq == 1 + is_causal = True, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def clear_caches(): + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + from torch._inductor.codecache import FxGraphCache + AOTAutogradCache.clear() + torch._dynamo.reset() + os.environ["TORCHINDUCTOR_CACHE"] = "0" + FxGraphCache.clear() + + +def assert_close(name, out, cpu_out, rtol=1e-4, atol=1e-4): + msg = f"|{name} Test Passed|" + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + print("-" * len(msg)) + print(msg) + print("-" * len(msg)) + else: + print(f"[FAIL] {name}") + print(" device out:", out.cpu()) + print(" cpu out:", cpu_out) + exit(1) + + +def _run_sdpa(device, q, k, v, **kwargs): + """Compile and run SDPA on device; return result on device.""" + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + return opt_fn(q.to(device), k.to(device), v.to(device), **kwargs) + + +def _cpu_sdpa(q, k, v, **kwargs): + """Run reference SDPA on CPU.""" + return F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), **kwargs) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +def test_sdpa( + device, + n_batch_list = SDPA_DEFAULTS["n_batch_list"], + n_head_list = SDPA_DEFAULTS["n_head_list"], + n_token_list = SDPA_DEFAULTS["n_token_list"], + head_dim_list = SDPA_DEFAULTS["head_dim_list"], + is_causal = SDPA_DEFAULTS["is_causal"], +): + torch.manual_seed(0) + sdpa_kwargs = dict(attn_mask=None, dropout_p=0.0, is_causal=is_causal) + + for B in n_batch_list: + for H in n_head_list: + for S in n_token_list: + for D in head_dim_list: + clear_caches() + q = torch.rand(B, H, S, D, dtype=torch.float32) + k = torch.rand(B, H, S, D, dtype=torch.float32) + v = torch.rand(B, H, S, D, dtype=torch.float32) + + out = _run_sdpa(device, q, k, v, **sdpa_kwargs) + cpu_out = _cpu_sdpa(q, k, v, **sdpa_kwargs) + + assert_close(f"SDPA(B:{B}, H:{H}, S:{S}, D:{D})", out, cpu_out) + + print("All SDPA tests passed!") + + +def test_gqa( + device, + batch_list = GQA_DEFAULTS["batch_list"], + num_kv_heads = GQA_DEFAULTS["num_kv_heads"], + gqa_ratios = GQA_DEFAULTS["gqa_ratios"], + seq_len_list = GQA_DEFAULTS["seq_len_list"], + head_dim_list= GQA_DEFAULTS["head_dim_list"], + query_len = GQA_DEFAULTS["query_len"], + is_causal = GQA_DEFAULTS["is_causal"], +): + """ + GQA sweep: q shape (B, Hq, Lq, D), kv shape (B, H, S, D). + Hq = ratio * num_kv_heads for each ratio in gqa_ratios. + """ + torch.manual_seed(0) + sdpa_kwargs = dict(attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) + + for B in batch_list: + for S in seq_len_list: + for D in head_dim_list: + for ratio in gqa_ratios: + Hq = ratio * num_kv_heads + clear_caches() + q = torch.rand(B, Hq, query_len, D, dtype=torch.float32) + k = torch.rand(B, num_kv_heads, S, D, dtype=torch.float32) + v = torch.rand(B, num_kv_heads, S, D, dtype=torch.float32) + + out = _run_sdpa(device, q, k, v, **sdpa_kwargs) + cpu_out = _cpu_sdpa(q, k, v, **sdpa_kwargs) + + assert_close( + f"GQA(B:{B}, Hq:{Hq}, H:{num_kv_heads}, S:{S}, D:{D})", + out, cpu_out, + ) + + print("All GQA tests passed!") + + +if __name__ == "__main__": + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]): + test_sdpa(device) + #test_gqa(device) + + # Example: quick single-config run + # test_gqa(device, batch_list=[1], gqa_ratios=[5], seq_len_list=[32], head_dim_list=[128]) diff --git a/tests/test_single_perceptron.py b/tests/test_single_perceptron.py index beab1c54..7d3401a3 100644 --- a/tests/test_single_perceptron.py +++ b/tests/test_single_perceptron.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -78,11 +76,5 @@ def weight_update(a, b, lr): # plt.savefig('result.png') if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_single_perceptron(device) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index e6e8cc1e..2dca97b7 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -42,25 +42,29 @@ def test_softmax(device, size=(128, 128), dim=1): #cpu_y = softmax3(x2, cpu_max, cpu_sum) #test_result("Softmax", y, cpu_y) - opt_fn = torch.compile(dynamic=False)(torch.nn.functional.softmax) - y = opt_fn(x1, dim=dim) + class SoftmaxModule(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.softmax(x, dim=self.dim) + + softmax_module = SoftmaxModule(dim=dim).to(device) + opt_fn = torch.compile(dynamic=False)(softmax_module) + y = opt_fn(x1) cpu_y = torch.nn.functional.softmax(x2, dim=dim) test_result("Softmax", y, cpu_y) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, help="Shape of the tensor in the format (batch_size, features)", default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_softmax(device, size=(64, 128)) test_softmax(device, size=(64, 128), dim=0) test_softmax(device, size=(256, 128)) diff --git a/tests/test_sort.py b/tests/test_sort.py new file mode 100644 index 00000000..5bce2532 --- /dev/null +++ b/tests/test_sort.py @@ -0,0 +1,121 @@ +import argparse +import torch + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def test_equal(name, out, cpu_out): + if torch.equal(out.cpu(), cpu_out): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + +def test_sort(device, size=(128, 128), dim=-1, descending=False, stable=True): + def sort_test(x): + return torch.sort(x, dim=dim, descending=descending, stable=stable) + + x = torch.randn(size, dtype=torch.float32) + x_npu = x.to(device=device) + + opt_sort = torch.compile(dynamic=False)(sort_test) + out_values, out_indices = opt_sort(x_npu) + ref_values, ref_indices = torch.sort(x, stable=stable, dim=dim, descending=descending) + + prefix = "Sort.stable" if stable else "Sort.unstable" + test_result(f"{prefix}/values size={size}, dim={dim}, desc={descending}", out_values, ref_values) + if stable: + test_result(f"{prefix}/indices size={size}, dim={dim}, desc={descending}", out_indices, ref_indices) + else: + # Unstable sort does not guarantee tie ordering; validate index-value consistency instead. + gathered = torch.gather(x, dim, out_indices.cpu()) + test_result(f"{prefix}/indices_gather size={size}, dim={dim}, desc={descending}", gathered, out_values.cpu()) + + +def test_sort_stable_suite(device): + # Keep sort-axis sizes compatible with backend constraints (vector-size multiple). + cases = [ + {"size": (64,), "dim": 0, "descending": False}, # 1D + {"size": (4, 64), "dim": 1, "descending": True}, # 2D, last dim + {"size": (2, 8, 32), "dim": 2, "descending": False}, # 3D, last dim + {"size": (2, 16, 4), "dim": 1, "descending": True}, # 3D, middle dim + {"size": (2, 4, 8, 32), "dim": 3, "descending": False}, # 4D, last dim + {"size": (4, 2, 32, 8), "dim": 2, "descending": True}, # 4D, inner dim + ] + for case in cases: + test_sort( + device=device, + size=case["size"], + dim=case["dim"], + descending=case["descending"], + stable=True, + ) + + +def test_sort_duplicate_cases(device): + duplicate_cases = [ + {"size": (64,), "dim": 0, "descending": False}, + {"size": (4, 64), "dim": 1, "descending": True}, + {"size": (2, 8, 32), "dim": 2, "descending": False}, + ] + for case in duplicate_cases: + base = torch.arange(case["size"][case["dim"]], dtype=torch.int64) % 7 + view_shape = [1] * len(case["size"]) + view_shape[case["dim"]] = case["size"][case["dim"]] + x = base.view(view_shape).expand(case["size"]).to(torch.float32) + noise = torch.randn(case["size"], dtype=torch.float32) * 0.0 + x = x + noise + + def sort_test(inp): + return torch.sort(inp, dim=case["dim"], descending=case["descending"], stable=True) + + out_values, out_indices = torch.compile(dynamic=False)(sort_test)(x.to(device=device)) + ref_values, ref_indices = torch.sort( + x, dim=case["dim"], descending=case["descending"], stable=True + ) + test_result(f"Sort.dup/stable_values {case}", out_values, ref_values) + test_equal(f"Sort.dup/stable_indices {case}", out_indices, ref_indices) + + def sort_test_unstable(inp): + return torch.sort(inp, dim=case["dim"], descending=case["descending"], stable=False) + + out_values_u, out_indices_u = torch.compile(dynamic=False)(sort_test_unstable)(x.to(device=device)) + ref_values_u, _ = torch.sort(x, dim=case["dim"], descending=case["descending"], stable=False) + test_result(f"Sort.dup/unstable_values {case}", out_values_u, ref_values_u) + gathered_u = torch.gather(x, case["dim"], out_indices_u.cpu()) + test_result(f"Sort.dup/unstable_gather {case}", gathered_u, out_values_u.cpu()) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run sort tests") + parser.add_argument("--shape", type=str, default="(64, 32, 16)") + parser.add_argument("--dim", type=int, default=0) + parser.add_argument("--descending", action="store_true") + args = parser.parse_args() + + shape = tuple(map(int, args.shape.strip("()").split(","))) + + device = torch.device("npu:0") + + test_sort_stable_suite(device) + test_sort_duplicate_cases(device) \ No newline at end of file diff --git a/tests/test_sparse_core.py b/tests/test_sparse_core.py index 72eda0c8..bb4ff630 100644 --- a/tests/test_sparse_core.py +++ b/tests/test_sparse_core.py @@ -80,9 +80,6 @@ def test_sparse_mlp(device, batch_size=32, input_size=128, hidden_size=128, outp import os import sys sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_sparse_mlp(device, batch_size=8, input_size=16, hidden_size=32, output_size=64) diff --git a/tests/test_sparsity.py b/tests/test_sparsity.py index a2493673..eaa7c63c 100644 --- a/tests/test_sparsity.py +++ b/tests/test_sparsity.py @@ -96,9 +96,7 @@ def test_mlp_inf(device, batch_size=64, input_size=64, hidden_size=32, output_si ) args = parser.parse_args() - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_dec_inf(device, sparsity=args.sparsity, block=args.block) test_mlp_inf(device, batch_size=32, input_size=784, hidden_size=512, output_size=256, sparsity=args.sparsity, block=args.block) diff --git a/tests/test_spmm_scheduler.py b/tests/test_spmm_scheduler.py deleted file mode 100644 index 71594eb2..00000000 --- a/tests/test_spmm_scheduler.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -import sys -import torch -import argparse -sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request -from test_sparse_core import SparseMLP as model1 -from test_transformer import EncoderBlock as model2 -CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="") - parser.add_argument("--batch_size", type=int, default=128, help="Batch size") - parser.add_argument("--input_size", type=int, default=128, help="Input layer size") - parser.add_argument("--hidden_size", type=int, default=128, help="Hidden layer size") - parser.add_argument("--output_size", type=int, default=128, help="Output layer size") - parser.add_argument("--w1_sparsity", type=float, default=0.5, help="Sparsity of first layer weights (0 to 1)") - parser.add_argument("--w2_sparsity", type=float, default=0.5, help="Sparsity of second layer weights (0 to 1)") - parser.add_argument("--config", type=str) - args = parser.parse_args() - - batch_size = args.batch_size - input_size = args.input_size - hidden_size = args.hidden_size - output_size = args.output_size - w1_sparsity = args.w1_sparsity - w2_sparsity = args.w2_sparsity - config_path = f"{CONFIG_TORCHSIM_DIR}/configs/{args.config}" - - print("batch_size: ", batch_size) - print("input_size: ", input_size) - print("hidden_size: ", hidden_size) - print("output_size: ", output_size) - print("w1_sparsity: ", w1_sparsity) - print("w2_sparsity: ", w2_sparsity) - - with torch.no_grad(): - # Init scheduler - scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, - togsim_config=config_path) - - target_model1 = model1(input_size, hidden_size, output_size, w1_sparsity, w2_sparsity, scheduler.execution_engine.module.custom_device()).eval() - target_model2 = model2(768, 12).eval() - - # Register compiled model - opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device())) - opt_model2 = torch.compile(target_model2.to(device=scheduler.execution_engine.module.custom_device())) - SchedulerDNNModel.register_model("mlp", opt_model1) - SchedulerDNNModel.register_model("bert", opt_model2) - - # Init input data - model_input1 = torch.randn(batch_size, input_size) - model_input2 = torch.randn(1, 512, 768) - - # Init request - new_request1 = Request("mlp", [model_input1], [], request_queue_idx=0) - #new_request2 = Request("bert", [model_input2], [], request_queue_idx=1) - - - # Add request to scheduler - scheduler.add_request(new_request1, request_time=0) - #scheduler.add_request(new_request2, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() \ No newline at end of file diff --git a/tests/test_stonne.py b/tests/test_stonne.py index 04ad05a8..ac26c273 100644 --- a/tests/test_stonne.py +++ b/tests/test_stonne.py @@ -54,7 +54,5 @@ def test_sparse_mm(device, input_size=128, hidden_size=128, output_size=128, spa args = parser.parse_args() sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_sparse_mm(device, args.sz, args.sz, args.sz, args.sparsity) \ No newline at end of file diff --git a/tests/test_topk.py b/tests/test_topk.py index 0d5c08ec..caf56779 100644 --- a/tests/test_topk.py +++ b/tests/test_topk.py @@ -31,24 +31,11 @@ def topk_fn(a): opt_topk = torch.compile(dynamic=False)(topk_fn) res_values, res_indices = opt_topk(x) - ref_values, ref_indices = torch.topk(x.cpu(), k, dim=dim, largest=largest, sorted=sorted) test_result("TopK/values", res_values, ref_values) test_result("TopK/indices", res_indices, ref_indices) if __name__ == "__main__": - import os - import sys - import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") - parser.add_argument('--shape', type=str, default="(512,768)") - args = parser.parse_args() - shape = tuple(map(int, args.shape.strip('()').split(','))) - - from Scheduler.scheduler import ExecutionEngine - module = ExecutionEngine.setup_device() - device = module.custom_device() + device = torch.device('npu:0') test_topk(device, (128, 128), k=2, dim=-1) \ No newline at end of file diff --git a/tests/test_transcendental.py b/tests/test_transcendental.py index 38c2f4f6..34546539 100644 --- a/tests/test_transcendental.py +++ b/tests/test_transcendental.py @@ -63,19 +63,14 @@ def cos(a): test_result("Cos", res, out) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_tanh(device) test_exp(device) test_erf(device) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index a3ac55d7..2b7f308c 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1,8 +1,6 @@ import math import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -115,13 +113,7 @@ def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): test_result("MHA Forward", res, cpu_res) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_EncoderBlock(device) # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/test_transpose2D.py b/tests/test_transpose2D.py index af5aacf7..4e9807ce 100644 --- a/tests/test_transpose2D.py +++ b/tests/test_transpose2D.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -42,13 +40,7 @@ def transpose(a, b): test_result("Transpose2 Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_Transpose2D(device, [64, 156]) test_Transpose2D_2(device, [16, 64]) test_Transpose2D(device, [640, 256]) diff --git a/tests/test_transpose3D.py b/tests/test_transpose3D.py index d6c1092d..e4d4e952 100644 --- a/tests/test_transpose3D.py +++ b/tests/test_transpose3D.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -57,13 +55,7 @@ def transpose(a, b): test_result("Transpose 3D Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_Transpose3D_1(device, [62, 34, 44]) test_Transpose3D_1(device, [62, 134, 144]) test_Transpose3D_2(device, [62, 34, 44]) diff --git a/tests/test_vectorops.py b/tests/test_vectorops.py index ed895171..90e9c0f5 100644 --- a/tests/test_vectorops.py +++ b/tests/test_vectorops.py @@ -1,14 +1,7 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") # Target shape seq_list = [1,128,512,2048,8192] diff --git a/tests/test_view3D_2D.py b/tests/test_view3D_2D.py index 148fe8fa..cc7b5e41 100644 --- a/tests/test_view3D_2D.py +++ b/tests/test_view3D_2D.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -40,13 +38,7 @@ def view2D_3D(a): test_result("view 2D->3D", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_view3D_2D(device) test_view3D_2D(device, [12, 512, 64]) test_view2D_3D(device, size=(512, 1024), h=16, d_k=64) diff --git a/tests/test_vit.py b/tests/test_vit.py index aeb4f148..6149166d 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -202,9 +202,7 @@ def test_encoder_block_with_class_token( shape = tuple(map(int, args.shape.strip('()').split(','))) sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_multihead_attention(device) #test_encoder_block(device, seq_len=197) #test_encoder_block_with_class_token(device, seq_len=196) diff --git a/thirdparty/github-releases.json b/thirdparty/github-releases.json new file mode 100644 index 00000000..25c220c9 --- /dev/null +++ b/thirdparty/github-releases.json @@ -0,0 +1,19 @@ +{ + "description": "GitHub release pins for CI (docker base image). pytorch_image is the ARG PYTORCH_IMAGE for Dockerfile.base. Use release_tag \"latest\" or an exact release tag for GitHub deps. asset_name must match the release attachment filename. CI builds ghcr.io/.../torchsim_base:thirdparty-<12 hex> when missing (pin = sha256 of this file plus Dockerfile.base) and updates :latest on that push.", + "pytorch_image": "pytorch/pytorch:2.8.0-cuda12.6-cudnn9-devel", + "gem5": { + "repository": "PSAL-POSTECH/gem5", + "release_tag": "v1.0.1", + "asset_name": "gem5-release.tar.gz" + }, + "llvm_project": { + "repository": "PSAL-POSTECH/llvm-project", + "release_tag": "v1.0.6", + "asset_name": "riscv-llvm-release.tar.gz" + }, + "spike": { + "repository": "PSAL-POSTECH/riscv-isa-sim", + "release_tag": "v1.0.1", + "asset_name": "spike-release.tar.gz" + } +} diff --git a/tutorial/jupyterhub/Dockerfile b/tutorial/jupyterhub/Dockerfile new file mode 100644 index 00000000..f98b2294 --- /dev/null +++ b/tutorial/jupyterhub/Dockerfile @@ -0,0 +1,7 @@ +FROM jupyterhub/jupyterhub:latest + +RUN pip install --no-cache-dir \ + dockerspawner \ + jupyterhub-nativeauthenticator + +WORKDIR /srv/jupyterhub diff --git a/Dockerfile.ksc2025 b/tutorial/jupyterhub/Dockerfile.ksc2025 similarity index 94% rename from Dockerfile.ksc2025 rename to tutorial/jupyterhub/Dockerfile.ksc2025 index 2ac210e0..7633c048 100644 --- a/Dockerfile.ksc2025 +++ b/tutorial/jupyterhub/Dockerfile.ksc2025 @@ -33,15 +33,15 @@ RUN apt -y update && apt -y upgrade && \ python3-dev python-is-python3 doxygen libboost-all-dev \ libhdf5-serial-dev python3-pydot libpng-dev libelf-dev pkg-config pip \ python3-venv black libssl-dev libasan5 libubsan1 -RUN pip install mypy pre-commit jupyter +RUN pip install mypy pre-commit jupyter pydot tabulate jupyterlab_execute_time # Pass Access Token securely ENV PATH=$PATH:/root/.local/bin ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/opt/conda/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH # Build Gem5 -RUN git clone https://github.com/PSAL-POSTECH/gem5.git --branch TorchSim -RUN cd gem5 && scons build/RISCV/gem5.opt -j $(nproc) +RUN git clone https://github.com/PSAL-POSTECH/gem5.git --branch tutorial +RUN cd gem5 && scons build/RISCV/gem5.opt -j $(nproc) && git checkout TorchSim ENV GEM5_PATH=/workspace/gem5/build/RISCV/gem5.opt # Build LLVM RISC-V @@ -52,9 +52,7 @@ RUN cd llvm-project && mkdir build && cd build && \ # Store RISC-V LLVM for TorchSim ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin -ENV TORCHSIM_LLVM_INCLUDE_PATH=/riscv-llvm/include ENV TORCHSIM_DIR=/workspace/PyTorchSim -ENV LLVM_DIR=/riscv-llvm # Download RISC-V tool chain RUN apt install -y wget && \ @@ -79,7 +77,7 @@ RUN git clone https://github.com/riscv-software-src/riscv-pk.git && \ # Install torchsim dependency RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0 -# Prepare ONNXim project +# Prepare PyTorchSim project RUN git clone https://github.com/PSAL-POSTECH/PyTorchSim.git --branch tutorial RUN cd PyTorchSim/TOGSim && \ git submodule update --recursive --init && \ @@ -87,4 +85,6 @@ RUN cd PyTorchSim/TOGSim && \ cd build && \ conan install .. --build=missing && \ cmake .. && \ - make -j$(nproc) \ No newline at end of file + make -j$(nproc) + +RUN pip install jupyterhub jupyterlab diff --git a/tutorial/jupyterhub/docker-compose.yml b/tutorial/jupyterhub/docker-compose.yml new file mode 100644 index 00000000..62c07ff1 --- /dev/null +++ b/tutorial/jupyterhub/docker-compose.yml @@ -0,0 +1,25 @@ +version: '3' + +services: + jupyterhub: + build: + context: . + dockerfile: Dockerfile + container_name: jupyterhub + image: my-jupyterhub-image + volumes: + - /var/run/docker.sock:/var/run/docker.sock + - ./jupyterhub_config.py:/srv/jupyterhub/jupyterhub_config.py + environment: + # DockerSpawner가 사용할 네트워크 이름 + DOCKER_NETWORK_NAME: jupyterhub-network + # Hub가 내부적으로 사용할 IP + HUB_IP: jupyterhub + ports: + - "8888:8000" + networks: + - jupyterhub-network + +networks: + jupyterhub-network: + external: true diff --git a/tutorial/jupyterhub/jupyterhub_config.py b/tutorial/jupyterhub/jupyterhub_config.py new file mode 100644 index 00000000..a43c0543 --- /dev/null +++ b/tutorial/jupyterhub/jupyterhub_config.py @@ -0,0 +1,28 @@ +import os + +c = get_config() + +# ------------------------------------------------------------------------------ +# Spawner config +# ------------------------------------------------------------------------------ +c.JupyterHub.spawner_class = 'dockerspawner.DockerSpawner' +c.DockerSpawner.image = "ghcr.io/psal-postech/torchsim_ksc2025:latest" + +# Resource limit +c.DockerSpawner.mem_limit = '16G' +c.DockerSpawner.cpu_limit = 4.0 + +c.DockerSpawner.network_name = 'jupyterhub-network' +c.Spawner.default_url = '/lab' +c.Spawner.ip = '0.0.0.0' +c.DockerSpawner.remove = False +c.DockerSpawner.cmd = ["jupyterhub-singleuser", "--allow-root"] + +c.JupyterHub.authenticator_class = 'nativeauthenticator.NativeAuthenticator' +c.Authenticator.admin_users = {'admin'} + +c.JupyterHub.hub_ip = 'jupyterhub' +c.JupyterHub.hub_port = 8081 + +c.NativeAuthenticator.open_signup = True +c.NativeAuthenticator.allow_all = True diff --git a/tutorial/jupyterhub/setting.sh b/tutorial/jupyterhub/setting.sh new file mode 100755 index 00000000..3e544839 --- /dev/null +++ b/tutorial/jupyterhub/setting.sh @@ -0,0 +1,5 @@ +if [ -z "$(docker network ls | grep jupyterhub-network)" ]; then + docker network create jupyterhub-network +fi + +docker compose up -d --build \ No newline at end of file diff --git a/tutorial/session1/CompilerOptimization.ipynb b/tutorial/session1/CompilerOptimization.ipynb index 178974c1..d17a6b25 100644 --- a/tutorial/session1/CompilerOptimization.ipynb +++ b/tutorial/session1/CompilerOptimization.ipynb @@ -18,7 +18,7 @@ "import sys\n", "base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')\n", "sys.path.append(base_dir)\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"" + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"" ] }, { @@ -35,8 +35,7 @@ "outputs": [], "source": [ "os.environ['TORCHSIM_DUMP_PATH']=os.path.join(os.getcwd(), \"fused\")\n", - "from Scheduler.scheduler import PyTorchSimRunner\n", - "device = PyTorchSimRunner.setup_device().custom_device()\n", + "device = torch.device(\"npu:0\")\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -71,7 +70,7 @@ "outputs": [], "source": [ "os.environ['TORCHSIM_DUMP_PATH']=os.path.join(os.getcwd(), \"non_fused\")\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", diff --git a/tutorial/session1/DNNServing.ipynb b/tutorial/session1/DNNServing.ipynb index b38bfe6a..741f463f 100644 --- a/tutorial/session1/DNNServing.ipynb +++ b/tutorial/session1/DNNServing.ipynb @@ -38,7 +38,7 @@ "from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request\n", "from PyTorchSimFrontend import extension_config\n", "\n", - "scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=extension_config.TOGSIM_CONFIG)\n", + "scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=extension_config.CONFIG_TOGSIM_CONFIG)\n", "device = scheduler.execution_engine.module.custom_device()\n", "\n", "model = resnet18().eval()\n", @@ -83,7 +83,7 @@ "target_model1 = resnet18().eval()\n", "\n", "# Init scheduler\n", - "scheduler = Scheduler(num_request_queue=1, max_batch=32, engine_select=Scheduler.FIFO_ENGINE, togsim_config=extension_config.TOGSIM_CONFIG)\n", + "scheduler = Scheduler(num_request_queue=1, max_batch=32, engine_select=Scheduler.FIFO_ENGINE, togsim_config=extension_config.CONFIG_TOGSIM_CONFIG)\n", "# Register compiled model\n", "opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last), dynamic=False)\n", "SchedulerDNNModel.register_model(\"resnet18\", opt_model1)\n", diff --git a/tutorial/session1/ExecutionMode.ipynb b/tutorial/session1/ExecutionMode.ipynb index 22e00bed..d94323db 100644 --- a/tutorial/session1/ExecutionMode.ipynb +++ b/tutorial/session1/ExecutionMode.ipynb @@ -33,8 +33,7 @@ "metadata": {}, "outputs": [], "source": [ - "from Scheduler.scheduler import PyTorchSimRunner\n", - "device = PyTorchSimRunner.setup_device().custom_device()\n", + "device = torch.device(\"npu:0\")\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -56,7 +55,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_functional_only.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_functional_only.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -78,7 +77,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -101,7 +100,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"\n", "\n", "input = torch.randn(2048, 2048).to(device=device)\n", "weight = torch.randn(2048, 2048).to(device=device)\n", @@ -132,7 +131,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_2_cores.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_2_cores.yml\"\n", "\n", "input = torch.randn(2048, 2048).to(device=device)\n", "weight = torch.randn(2048, 2048).to(device=device)\n", diff --git a/tutorial/session1/Inference.ipynb b/tutorial/session1/Inference.ipynb index a49e2440..6fd54aed 100644 --- a/tutorial/session1/Inference.ipynb +++ b/tutorial/session1/Inference.ipynb @@ -57,8 +57,7 @@ "metadata": {}, "outputs": [], "source": [ - "from Scheduler.scheduler import PyTorchSimRunner\n", - "device = PyTorchSimRunner.setup_device().custom_device()\n", + "device = torch.device(\"npu:0\")\n", "\n", "torch.manual_seed(0)\n", "input = torch.randn(128, 128).to(device)\n", diff --git a/tutorial/session1/LogAnalysis.ipynb b/tutorial/session1/LogAnalysis.ipynb index 4f1e17cb..24dae52b 100644 --- a/tutorial/session1/LogAnalysis.ipynb +++ b/tutorial/session1/LogAnalysis.ipynb @@ -18,8 +18,8 @@ "import sys\n", "base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')\n", "sys.path.append(base_dir)\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", - "os.environ['TORCHSIM_DUMP_LOG_PATH']=os.path.join(os.getcwd(), \"togsim_results\")" + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.yml\"\n", + "os.environ['TORCHSIM_LOG_PATH']=os.path.join(os.getcwd(), \"togsim_results\")" ] }, { @@ -35,8 +35,7 @@ "metadata": {}, "outputs": [], "source": [ - "from Scheduler.scheduler import PyTorchSimRunner\n", - "device = PyTorchSimRunner.setup_device().custom_device()\n", + "device = torch.device(\"npu:0\")\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", diff --git a/tutorial/session1/Mapping.ipynb b/tutorial/session1/Mapping.ipynb index b02c98fe..0b978bcb 100644 --- a/tutorial/session1/Mapping.ipynb +++ b/tutorial/session1/Mapping.ipynb @@ -33,8 +33,7 @@ "metadata": {}, "outputs": [], "source": [ - "from Scheduler.scheduler import PyTorchSimRunner\n", - "device = PyTorchSimRunner.setup_device().custom_device()\n", + "device = torch.device(\"npu:0\")\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -68,7 +67,7 @@ "source": [ "torch._dynamo.reset()\n", "\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_external_mapping.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_external_mapping.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", @@ -101,7 +100,7 @@ "source": [ "torch._dynamo.reset()\n", "\n", - "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_autotune.json\"\n", + "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_autotune.yml\"\n", "\n", "input = torch.randn(1024, 1024).to(device=device)\n", "weight = torch.randn(1024, 1024).to(device=device)\n", diff --git a/tutorial/session1/Training.ipynb b/tutorial/session1/Training.ipynb index 0c6b138a..badf7ed7 100644 --- a/tutorial/session1/Training.ipynb +++ b/tutorial/session1/Training.ipynb @@ -20,8 +20,7 @@ "sys.path.append(base_dir)\n", "\n", "cpu_device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "from Scheduler.scheduler import PyTorchSimRunner\n", - "npu_device = PyTorchSimRunner.setup_device().custom_device()" + "npu_device = torch.device(\"npu:0\")" ] }, { diff --git a/tutorial/session1/togsim_configs/togsim_config.json b/tutorial/session1/togsim_configs/togsim_config.json deleted file mode 100644 index e8e489d9..00000000 --- a/tutorial/session1/togsim_configs/togsim_config.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/tutorial/session1/togsim_configs/togsim_config.yml b/tutorial/session1/togsim_configs/togsim_config.yml new file mode 100644 index 00000000..72873f1c --- /dev/null +++ b/tutorial/session1/togsim_configs/togsim_config.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/tutorial/session1/togsim_configs/togsim_config_2_cores.json b/tutorial/session1/togsim_configs/togsim_config_2_cores.json deleted file mode 100644 index c50edaa9..00000000 --- a/tutorial/session1/togsim_configs/togsim_config_2_cores.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 2, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 32, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 0, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/tutorial/session1/togsim_configs/togsim_config_2_cores.yml b/tutorial/session1/togsim_configs/togsim_config_2_cores.yml new file mode 100644 index 00000000..3b9b8fc8 --- /dev/null +++ b/tutorial/session1/togsim_configs/togsim_config_2_cores.yml @@ -0,0 +1,30 @@ +num_cores: 2 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 32 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 0 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/tutorial/session1/togsim_configs/togsim_config_autotune.json b/tutorial/session1/togsim_configs/togsim_config_autotune.json deleted file mode 100644 index c9763e92..00000000 --- a/tutorial/session1/togsim_configs/togsim_config_autotune.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "autotune", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/tutorial/session1/togsim_configs/togsim_config_autotune.yml b/tutorial/session1/togsim_configs/togsim_config_autotune.yml new file mode 100644 index 00000000..2726736a --- /dev/null +++ b/tutorial/session1/togsim_configs/togsim_config_autotune.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: autotune +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/tutorial/session1/togsim_configs/togsim_config_external_mapping.json b/tutorial/session1/togsim_configs/togsim_config_external_mapping.json deleted file mode 100644 index c8ddb0f3..00000000 --- a/tutorial/session1/togsim_configs/togsim_config_external_mapping.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "external-then-heuristic", - "codegen_external_mapping_file" : "/workspace/PyTorchSim/tutorial/session1/tutorial_external_mapping.json", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/tutorial/session1/togsim_configs/togsim_config_external_mapping.yml b/tutorial/session1/togsim_configs/togsim_config_external_mapping.yml new file mode 100644 index 00000000..468a0b44 --- /dev/null +++ b/tutorial/session1/togsim_configs/togsim_config_external_mapping.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: external-then-heuristic +codegen_external_mapping_file: /workspace/PyTorchSim/tutorial/session1/tutorial_external_mapping.json +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/tutorial/session1/togsim_configs/togsim_config_functional_only.json b/tutorial/session1/togsim_configs/togsim_config_functional_only.json deleted file mode 100644 index 53072307..00000000 --- a/tutorial/session1/togsim_configs/togsim_config_functional_only.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 1, - "pytorchsim_timing_mode" : 0, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/tutorial/session1/togsim_configs/togsim_config_functional_only.yml b/tutorial/session1/togsim_configs/togsim_config_functional_only.yml new file mode 100644 index 00000000..a1f1b432 --- /dev/null +++ b/tutorial/session1/togsim_configs/togsim_config_functional_only.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 1 +pytorchsim_timing_mode: 0 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.json b/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.json deleted file mode 100644 index e2b9c8c8..00000000 --- a/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 0, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "none" -} \ No newline at end of file diff --git a/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.yml b/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.yml new file mode 100644 index 00000000..62d627a6 --- /dev/null +++ b/tutorial/session1/togsim_configs/togsim_config_no_compiler_optimization.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 0 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: none diff --git a/tutorial/session1/togsim_configs/togsim_config_timing_only.json b/tutorial/session1/togsim_configs/togsim_config_timing_only.json deleted file mode 100644 index 0b846bbd..00000000 --- a/tutorial/session1/togsim_configs/togsim_config_timing_only.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "num_cores" : 1, - "core_freq_mhz" : 940, - "core_stats_print_period_cycles" : 10000, - "num_systolic_array_per_core" : 2, - - "vpu_num_lanes" : 128, - "vpu_spad_size_kb_per_lane" : 128, - "vpu_vector_length_bits" : 256, - - "dram_type" : "ramulator2", - "dram_freq_mhz" : 940, - "dram_channels": 16, - "dram_req_size_byte": 32, - "dram_num_burst_length" : 2, - "dram_stats_print_period_cycles": 10000, - "ramulator_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", - - "icnt_type" : "simple", - "icnt_latency_cycles" : 10, - "icnt_freq_mhz" : 940, - "icnt_injection_ports_per_core" : 16, - - "pytorchsim_functional_mode" : 0, - "pytorchsim_timing_mode" : 1, - - "codegen_mapping_strategy" : "heuristic", - "codegen_external_mapping_file" : "", - "codegen_autotune_max_retry": 10, - "codegen_autotune_template_topk": 4, - "codegen_compiler_optimization" : "all" -} \ No newline at end of file diff --git a/tutorial/session1/togsim_configs/togsim_config_timing_only.yml b/tutorial/session1/togsim_configs/togsim_config_timing_only.yml new file mode 100644 index 00000000..0024c073 --- /dev/null +++ b/tutorial/session1/togsim_configs/togsim_config_timing_only.yml @@ -0,0 +1,30 @@ +num_cores: 1 +core_freq_mhz: 940 +core_stats_print_period_cycles: 10000 +num_systolic_array_per_core: 2 + +vpu_num_lanes: 128 +vpu_spad_size_kb_per_lane: 128 +vpu_vector_length_bits: 256 + +dram_type: ramulator2 +dram_freq_mhz: 940 +dram_channels: 16 +dram_req_size_byte: 32 +dram_num_burst_length: 2 +dram_stats_print_period_cycles: 10000 +ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml + +icnt_type: simple +icnt_latency_cycles: 10 +icnt_freq_mhz: 940 +icnt_injection_ports_per_core: 16 + +pytorchsim_functional_mode: 0 +pytorchsim_timing_mode: 1 + +codegen_mapping_strategy: heuristic +codegen_external_mapping_file: '' +codegen_autotune_max_retry: 10 +codegen_autotune_template_topk: 4 +codegen_compiler_optimization: all diff --git a/tutorial/session2/Hands_on.ipynb b/tutorial/session2/Hands_on.ipynb index 33ec1a28..9a7c35e3 100644 --- a/tutorial/session2/Hands_on.ipynb +++ b/tutorial/session2/Hands_on.ipynb @@ -32,11 +32,10 @@ "import torch._dynamo\n", "import torch.utils.cpp_extension\n", "base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')\n", + "os.environ['TORCHSIM_LOG_PATH']=os.path.join(os.getcwd(), \"togsim_results\")\n", "sys.path.append(base_dir)\n", "\n", - "from Scheduler.scheduler import PyTorchSimRunner\n", - "module = PyTorchSimRunner.setup_device()\n", - "device = module.custom_device()\n", + "device = torch.device(\"npu:0\")\n", "\n", "def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4):\n", " if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol):\n", diff --git a/tutorial/session2/Warmup.py b/tutorial/session2/Warmup.py index ce215cf5..a45734ad 100644 --- a/tutorial/session2/Warmup.py +++ b/tutorial/session2/Warmup.py @@ -1,13 +1,19 @@ from typing import List import os from torch.fx.passes.graph_drawer import FxGraphDrawer -os.environ['TORCH_LOGS'] = 'bytecode' import torch +import inspect def dummy_compiler(gm: torch.fx.GraphModule, _): - gm.graph.print_tabular() + sep = "-" * 80 drawer = FxGraphDrawer(gm, "my_model") drawer.get_dot_graph().write_svg("fx_graph.svg") + + print(f"\n{sep}\n[1] FX Graph Tabular View\n{sep}") + gm.graph.print_tabular() + + print(f"\n{sep}\n[2] Generated Forward Source Code\n{sep}") + print(inspect.getsource(gm.forward)) return gm.forward # Return a callable object class MyModel(torch.nn.Module): @@ -23,5 +29,4 @@ def f(x, y): if __name__ == "__main__": x = torch.randn(7, 5,requires_grad=False) y = torch.randn(5, 3,requires_grad=False) - k = f(x, y) - print(k) + k = f(x, y) \ No newline at end of file