
[Bug] remove_unused_args incompatible with VarUseDefAnalysis in make_packed_api #98

@qelk123

Description


Expected behavior

Parameters removed by the remove_unused_args pass should not be checked by the make_packed_api pass.

Actual behavior

In a PrimFunc with dynamic shapes, variable args that are used to describe axis metadata (the shape of the indices, indptr or data buffers) are removed by the remove_unused_args pass, but they are still required by the VarUseDefAnalysis in the make_packed_api pass.

Case

TVM Script kernel:

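from tvm.script import tir as T

# T.dense_fixed, T.sparse_variable, T.match_sparse_buffer and T.sp_iter are SparseTIR extensions to TVMScript.
@T.prim_func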
def csrmm(
    a: T.handle,
    b: T.handle,
    c: T.handle,
    indptr: T.handle,
    indices: T.handle,
    m: T.int32,
    n: T.int32,
    num_tiles: T.int32,
    nnz: T.int32,
    cwm: T.int32,
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2})
    I = T.dense_fixed(m)
    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
    J_detach = T.dense_fixed(n)
    K1 = T.dense_fixed(num_tiles)
    K2 = T.dense_fixed(cwm)
    K3 = T.dense_fixed(32)
    A = T.match_sparse_buffer(a, (I, J), "float32")
    B = T.match_sparse_buffer(b, (J_detach, K1, K2, K3), "float32")
    C = T.match_sparse_buffer(c, (I, K1, K2, K3), "float32")
    with T.sp_iter([I, J, K1, K2, K3], "SRSSS", "csrmm") as [i, j, k1, k2, k3]:
        with T.init():
            C[i, k1, k2, k3] = 0.0
        C[i, k1, k2, k3] = C[i, k1, k2, k3] + A[i, j] * B[j, k1, k2, k3]

PrimFunc before make_packed_api (decomposed format with 5 buckets and 2 tile blocks):

primfn(b: handle, c: handle, num_tiles: int32, cwm: int32, a_0_0: handle, indices_i_0_0: handle, indices_j_0_0: handle, num_rows_0_0: int32, a_0_1: handle, indices_i_0_1: handle, indices_j_0_1: handle, num_rows_0_1: int32, a_0_2: handle, indices_i_0_2: handle, indices_j_0_2: handle, num_rows_0_2: int32, a_0_3: handle, indices_i_0_3: handle, indices_j_0_3: handle, num_rows_0_3: int32, a_0_4: handle, indices_i_0_4: handle, indices_j_0_4: handle, num_rows_0_4: int32, a_0_5: handle, indices_i_0_5: handle, indices_j_0_5: handle, num_rows_0_5: int32, a_1_0: handle, indices_i_1_0: handle, indices_j_1_0: handle, num_rows_1_0: int32, a_1_1: handle, indices_i_1_1: handle, indices_j_1_1: handle, num_rows_1_1: int32, a_1_2: handle, indices_i_1_2: handle, indices_j_1_2: handle, num_rows_1_2: int32, a_1_3: handle, indices_i_1_3: handle, indices_j_1_3: handle, num_rows_1_3: int32, a_1_4: handle, indices_i_1_4: handle, indices_j_1_4: handle, num_rows_1_4: int32, a_1_5: handle, indices_i_1_5: handle, indices_j_1_5: handle, num_rows_1_5: int32) -> ()
  attr = {"target": Target(kind='cuda', keys={'cuda', 'gpu'}, attrs={'thread_warp_size': 32, 'max_num_threads': 1024, 'arch': "sm_75"}, host=Target(kind='llvm', keys={'cpu'}, attrs={'link-params': (bool)0})), "tir.noalias": True, "global_symbol": "main", "composable": 1, "sparse_tir_level": 0, "tir.is_entry_func": True}
  buffers = {B_data: Buffer(B: Pointer(global float32), float32, [(((n: int32*num_tiles)*cwm)*32)], []),
             C_data: Buffer(C: Pointer(global float32), float32, [(((m: int32*num_tiles)*cwm)*32)], []),
             A_0_0_data: Buffer(A_0_0: Pointer(global float32), float32, [num_rows_0_0], []),
             I_0_0_indices_data: Buffer(I_0_0_indices.data: Pointer(global int32), int32, [num_rows_0_0], []),
             J_0_0_indices_data: Buffer(J_0_0_indices.data: Pointer(global int32), int32, [num_rows_0_0], []),
             ...
  buffer_map = {...} {
	...
  }

This example shows that m and n are removed from the argument list because they are only used to describe the shapes of the B_data and C_data buffers and some block read and write regions. The remove_unused_args pass does not scan any of these places, so it treats m and n as unused and drops them from the argument list.
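The failure mode can be illustrated with a toy model. This is a minimal sketch, not SparseTIR's implementation: the Buffer and ToyPrimFunc classes and remove_unused_args_naive are hypothetical, shapes are modelled as per-dimension lists rather than the flattened product in the dump above, and "used" means "referenced by the body". Because buffer shapes are never scanned, m and n disappear.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class Buffer:
    # Hypothetical toy buffer: dims are var names or integer literals as strings.
    name: str
    shape: List[str]


@dataclass
class ToyPrimFunc:
    # Hypothetical toy function; not TVM's tir.PrimFunc.
    params: List[str]
    buffers: List[Buffer]
    body_vars: Set[str]  # vars referenced by statements/expressions in the body


def remove_unused_args_naive(func: ToyPrimFunc) -> ToyPrimFunc:
    """Keep only params referenced by the body.

    Buffer shapes (and, in the real IR, block read/write regions) are not
    scanned, so params that appear only there are dropped.
    """
    kept = [p for p in func.params if p in func.body_vars]
    return ToyPrimFunc(kept, func.buffers, func.body_vars)


func = ToyPrimFunc(
    params=["b", "c", "m", "n", "num_tiles", "cwm"],
    buffers=[
        Buffer("B_data", ["n", "num_tiles", "cwm", "32"]),
        Buffer("C_data", ["m", "num_tiles", "cwm", "32"]),
    ],
    body_vars={"b", "c", "num_tiles", "cwm"},
)
print(remove_unused_args_naive(func).params)  # ['b', 'c', 'num_tiles', 'cwm']
```

m and n are still referenced by the shapes of B_data and C_data after the pass, which is exactly the state make_packed_api then rejects.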

However, VarUseDefAnalysis still checks these args, because they are needed to validate the DLTensor shapes of B_data and C_data. This produces the following error:

TVMError: Not all Vars are passed in api_args: 'n' 'm' is not bound to any variables
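The check that fails can be modelled as "every symbolic dim in a buffer shape must be bound by some param". The sketch below is a hypothetical stand-in for that use-def check (undefined_shape_vars and check_all_vars_bound are illustrative names, not TVM functions); it only mimics the wording of the error above.

```python
from typing import Dict, List


def undefined_shape_vars(params: List[str], buffer_shapes: Dict[str, List[str]]) -> List[str]:
    """Return symbolic dims used in buffer shapes that no param binds (first-seen order)."""
    bound = set(params)
    missing: List[str] = []
    for shape in buffer_shapes.values():
        for dim in shape:
            if not dim.isdigit() and dim not in bound and dim not in missing:
                missing.append(dim)
    return missing


def check_all_vars_bound(params: List[str], buffer_shapes: Dict[str, List[str]]) -> None:
    """Toy stand-in for the use-def check: fail if any shape var is unbound."""
    missing = undefined_shape_vars(params, buffer_shapes)
    if missing:
        names = " ".join(f"'{v}'" for v in missing)
        raise RuntimeError(
            f"Not all Vars are passed in api_args: {names} is not bound to any variables"
        )


# Params left after remove_unused_args; m and n now only appear in buffer shapes.
params = ["b", "c", "num_tiles", "cwm"]
shapes = {
    "B_data": ["n", "num_tiles", "cwm", "32"],
    "C_data": ["m", "num_tiles", "cwm", "32"],
}
try:
    check_all_vars_bound(params, shapes)
except RuntimeError as err:
    print(err)
```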

Possible solutions

Keep vars that appear in a BufferNode's shape and strides fields in the argument list. This could be done by adapting the way the VarUseDefAnalysis pass handles BufferNode.
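A minimal sketch of the proposed rule, again on a hypothetical toy IR (Buffer, vars_in_buffers and remove_unused_args_keep_shape_vars are illustrative names, not SparseTIR code): a param counts as used if the body or any buffer's shape or strides references it. The extra param unused_arg is also hypothetical and shows that truly unused args are still dropped.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Buffer:
    # Hypothetical toy buffer; dims/strides are var names or integer literals as strings.
    name: str
    shape: List[str]
    strides: List[str] = field(default_factory=list)


def vars_in_buffers(buffers: List[Buffer]) -> Set[str]:
    """Collect symbolic vars from every buffer's shape and strides."""
    used: Set[str] = set()
    for buf in buffers:
        for dim in buf.shape + buf.strides:
            if not dim.isdigit():
                used.add(dim)
    return used


def remove_unused_args_keep_shape_vars(
    params: List[str], body_vars: Set[str], buffers: List[Buffer]
) -> List[str]:
    """Drop a param only if neither the body nor any buffer shape/stride uses it."""
    used = set(body_vars) | vars_in_buffers(buffers)
    return [p for p in params if p in used]


buffers = [
    Buffer("B_data", ["n", "num_tiles", "cwm", "32"]),
    Buffer("C_data", ["m", "num_tiles", "cwm", "32"]),
]
params = ["b", "c", "m", "n", "num_tiles", "cwm", "unused_arg"]
body_vars = {"b", "c", "num_tiles", "cwm"}
print(remove_unused_args_keep_shape_vars(params, body_vars, buffers))
# ['b', 'c', 'm', 'n', 'num_tiles', 'cwm']
```

Scanning strides as well as shape matters for non-compact buffers, where a stride can be the only place a dynamic var appears.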

This solution also would not affect the current case, since these vars are already eliminated there by specializing them to constant values before the remove_unused_args pass runs.
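Why the current case is unaffected can be checked on the same kind of toy model. In this sketch, specialize is a hypothetical helper that only models the effect of binding vars to constants (it is not TVM's PrimFunc.specialize), and the values 128 and 256 are arbitrary. Once m and n are folded into the shapes, no shape var is left for the new rule to keep, so both removal rules agree.

```python
from typing import Dict, List, Set, Tuple


def specialize(
    params: List[str], shapes: Dict[str, List[str]], values: Dict[str, int]
) -> Tuple[List[str], Dict[str, List[str]]]:
    """Bind params to constants: drop them from params and fold them into shapes."""
    new_params = [p for p in params if p not in values]
    new_shapes = {
        name: [str(values[d]) if d in values else d for d in shape]
        for name, shape in shapes.items()
    }
    return new_params, new_shapes


def shape_vars(shapes: Dict[str, List[str]]) -> Set[str]:
    """Symbolic (non-literal) dims across all buffer shapes."""
    return {d for shape in shapes.values() for d in shape if not d.isdigit()}


params = ["b", "c", "m", "n", "num_tiles", "cwm"]
shapes = {
    "B_data": ["n", "num_tiles", "cwm", "32"],
    "C_data": ["m", "num_tiles", "cwm", "32"],
}
body_vars = {"b", "c", "num_tiles", "cwm"}

# Specialize m and n with example constants before removing unused args.
params, shapes = specialize(params, shapes, {"m": 128, "n": 256})

naive = [p for p in params if p in body_vars]
used = body_vars | shape_vars(shapes)
keep_shape_vars = [p for p in params if p in used]
print(naive == keep_shape_vars)  # True
```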
