Description
Expected behavior
Params removed by the remove_unused_args pass shouldn't be checked in the make_packed_api pass.
Actual behavior
In a PrimFunc with dynamic behavior, variable args that are used only to describe axis metadata (the shape of the indices buffer, indptr buffer, or data buffer) are removed by the remove_unused_args pass, but they are still required by VarUseDefAnalysis in the make_packed_api pass.
Case
TVM Script kernel:
@T.prim_func
def csrmm(
    a: T.handle,
    b: T.handle,
    c: T.handle,
    indptr: T.handle,
    indices: T.handle,
    m: T.int32,
    n: T.int32,
    num_tiles: T.int32,
    nnz: T.int32,
    cwm: T.int32,
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2})
    I = T.dense_fixed(m)
    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
    J_detach = T.dense_fixed(n)
    K1 = T.dense_fixed(num_tiles)
    K2 = T.dense_fixed(cwm)
    K3 = T.dense_fixed(32)
    A = T.match_sparse_buffer(a, (I, J), "float32")
    B = T.match_sparse_buffer(b, (J_detach, K1, K2, K3), "float32")
    C = T.match_sparse_buffer(c, (I, K1, K2, K3), "float32")
    with T.sp_iter([I, J, K1, K2, K3], "SRSSS", "csrmm") as [i, j, k1, k2, k3]:
        with T.init():
            C[i, k1, k2, k3] = 0.0
        C[i, k1, k2, k3] = C[i, k1, k2, k3] + A[i, j] * B[j, k1, k2, k3]
Primfunc before make_packed_api (decompose format with 5 buckets and 2 tile blocks)
primfn(b: handle, c: handle, num_tiles: int32, cwm: int32, a_0_0: handle, indices_i_0_0: handle, indices_j_0_0: handle, num_rows_0_0: int32, a_0_1: handle, indices_i_0_1: handle, indices_j_0_1: handle, num_rows_0_1: int32, a_0_2: handle, indices_i_0_2: handle, indices_j_0_2: handle, num_rows_0_2: int32, a_0_3: handle, indices_i_0_3: handle, indices_j_0_3: handle, num_rows_0_3: int32, a_0_4: handle, indices_i_0_4: handle, indices_j_0_4: handle, num_rows_0_4: int32, a_0_5: handle, indices_i_0_5: handle, indices_j_0_5: handle, num_rows_0_5: int32, a_1_0: handle, indices_i_1_0: handle, indices_j_1_0: handle, num_rows_1_0: int32, a_1_1: handle, indices_i_1_1: handle, indices_j_1_1: handle, num_rows_1_1: int32, a_1_2: handle, indices_i_1_2: handle, indices_j_1_2: handle, num_rows_1_2: int32, a_1_3: handle, indices_i_1_3: handle, indices_j_1_3: handle, num_rows_1_3: int32, a_1_4: handle, indices_i_1_4: handle, indices_j_1_4: handle, num_rows_1_4: int32, a_1_5: handle, indices_i_1_5: handle, indices_j_1_5: handle, num_rows_1_5: int32) -> ()
attr = {"target": Target(kind='cuda', keys={'cuda', 'gpu'}, attrs={'thread_warp_size': 32, 'max_num_threads': 1024, 'arch': "sm_75"}, host=Target(kind='llvm', keys={'cpu'}, attrs={'link-params': (bool)0})), "tir.noalias": True, "global_symbol": "main", "composable": 1, "sparse_tir_level": 0, "tir.is_entry_func": True}
buffers = {B_data: Buffer(B: Pointer(global float32), float32, [(((n: int32*num_tiles)*cwm)*32)], []),
C_data: Buffer(C: Pointer(global float32), float32, [(((m: int32*num_tiles)*cwm)*32)], []),
A_0_0_data: Buffer(A_0_0: Pointer(global float32), float32, [num_rows_0_0], []),
I_0_0_indices_data: Buffer(I_0_0_indices.data: Pointer(global int32), int32, [num_rows_0_0], []),
J_0_0_indices_data: Buffer(J_0_0_indices.data: Pointer(global int32), int32, [num_rows_0_0], []),
...
buffer_map = {...} {
...
}
From this example we can see that m and n are removed from the arg list. They are only used to describe the shapes of the B_data and C_data buffers and some block read and write regions, and the remove_unused_args pass does not count any of these places as uses.
However, these args are still checked by VarUseDefAnalysis in the make_packed_api pass, where they are needed to check the DLTensor shapes of B_data and C_data. This causes the error message:
TVMError: Not all Vars are passed in api_args: 'n' 'm' is not bound to any variables
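The mismatch between the two passes can be illustrated with a toy model in plain Python. This is not TVM code: `ToyBuffer`, `ToyFunc`, and the two functions below are hypothetical stand-ins, and the sketch assumes that the removal pass scans only the function body while the use-def check also visits buffer shapes.

```python
from dataclasses import dataclass, field


@dataclass
class ToyBuffer:
    name: str
    shape_vars: set  # scalar vars that appear in the buffer's shape


@dataclass
class ToyFunc:
    params: list     # scalar parameter names
    body_vars: set   # vars referenced directly in the function body
    buffers: list = field(default_factory=list)


def remove_unused_args(func):
    """Toy removal pass: keep only params referenced in the body (assumption)."""
    kept = [p for p in func.params if p in func.body_vars]
    return ToyFunc(kept, func.body_vars, func.buffers)


def undefined_vars(func):
    """Toy use-def check: body vars and buffer shape vars must all be params."""
    used = set(func.body_vars)
    for buf in func.buffers:
        used |= buf.shape_vars
    return sorted(used - set(func.params))


func = ToyFunc(
    params=["m", "n", "num_tiles", "cwm"],
    body_vars={"num_tiles", "cwm"},
    buffers=[
        ToyBuffer("B_data", {"n", "num_tiles", "cwm"}),
        ToyBuffer("C_data", {"m", "num_tiles", "cwm"}),
    ],
)
pruned = remove_unused_args(func)
missing = undefined_vars(pruned)
print(pruned.params)  # params kept by the toy removal pass
print(missing)        # vars the toy check reports as unbound
```

In this model the removal pass drops m and n because the body never references them, and the later check then reports exactly those two names as unbound.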
Possible solutions
Keep vars used in the shape and strides fields of a BufferNode in the arg list. We can adapt the approach that the VarUseDefAnalysis pass takes for BufferNode.
Also, this solution won't affect existing cases, since there we remove these vars by specializing them to constant values before the remove_unused_args pass.
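A minimal sketch of the proposed rule, again as a toy model in plain Python rather than the actual pass. The function name `remove_unused_args_keep_shape_vars` and the dictionary layout are hypothetical; the only idea taken from the proposal is that a param counts as used when a buffer shape or stride references it.

```python
def remove_unused_args_keep_shape_vars(params, body_vars, buffer_shape_vars):
    """Keep a param if the body or any buffer shape/stride references it."""
    used = set(body_vars)
    for shape_vars in buffer_shape_vars.values():
        used |= set(shape_vars)
    return [p for p in params if p in used]


params = ["m", "n", "num_tiles", "cwm"]
body_vars = {"num_tiles", "cwm"}

# Dynamic case: m and n still appear in the buffer shapes, so they are kept.
dynamic_shapes = {
    "B_data": {"n", "num_tiles", "cwm"},
    "C_data": {"m", "num_tiles", "cwm"},
}
kept_dynamic = remove_unused_args_keep_shape_vars(params, body_vars, dynamic_shapes)

# Specialized case: m and n were replaced by constants beforehand, so the
# shapes no longer mention them and they are removed as before.
specialized_shapes = {
    "B_data": {"num_tiles", "cwm"},
    "C_data": {"num_tiles", "cwm"},
}
kept_specialized = remove_unused_args_keep_shape_vars(params, body_vars, specialized_shapes)

print(kept_dynamic)
print(kept_specialized)
```

The second call shows why specialization is unaffected under this model: once m and n are constants, nothing references them and they are dropped from the arg list as they are today.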