I have a piece of code like this:
def convert(tir_expr: tir.expr.PrimExpr) -> expr.Expr:
    if isinstance(tir_expr, tir.expr.Add):
        return convert_Add(tir_expr)
    elif isinstance(tir_expr, tir.expr.Sub):
        return convert_Sub(tir_expr)
There are many such cases, one convert_XXX per class. I'd like to get the following result, i.e. repeat the same branch with XXX replaced by each YYY in turn:
def convert(tir_expr: tir.expr.PrimExpr) -> expr.Expr:
    if isinstance(tir_expr, tir.expr.Add):
        return convert_Add(tir_expr)
    elif isinstance(tir_expr, tir.expr.Sub):
        return convert_Sub(tir_expr)
    elif isinstance(tir_expr, tir.expr.Mul):
        return convert_Mul(tir_expr)
    elif isinstance(tir_expr, tir.expr.Div):
        return convert_Div(tir_expr)
    elif isinstance(tir_expr, tir.expr.IntImm):
        return convert_IntImm(tir_expr)
    elif isinstance(tir_expr, tir.expr.FloatImm):
        return convert_FloatImm(tir_expr)
    elif isinstance(tir_expr, tir.expr.StringImm):
        return convert_StringImm(tir_expr)
I don't know how to achieve this, or even what to search for to find an answer. Any tips?
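To make the expansion concrete, here is a minimal sketch of one possible approach: generating the branches from a list of class names with a short Python script and pasting the output into the source file. The helper name `generate_convert` and the `NAMES` list are hypothetical, introduced only for illustration; the script only builds text and does not import `tir` or `expr`.

```python
# Hypothetical list of class names; each one yields a branch calling convert_<name>.
NAMES = ["Add", "Sub", "Mul", "Div", "IntImm", "FloatImm", "StringImm"]


def generate_convert(names):
    """Return the source text of convert() with one isinstance branch per name."""
    lines = ["def convert(tir_expr: tir.expr.PrimExpr) -> expr.Expr:"]
    for i, name in enumerate(names):
        # The first branch uses "if", every later one uses "elif".
        keyword = "if" if i == 0 else "elif"
        lines.append(f"    {keyword} isinstance(tir_expr, tir.expr.{name}):")
        lines.append(f"        return convert_{name}(tir_expr)")
    return "\n".join(lines)


if __name__ == "__main__":
    # Print the generated function so it can be copied into the source file.
    print(generate_convert(NAMES))
```

Adding another case then only means appending its name to `NAMES` and regenerating. The same template-and-list idea could presumably be done inside an editor with a macro or snippet instead of a script.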
elif isinstance(tir_expr, tir.expr.Sub): return convert_Sub(tir_expr)? – Tobias May 27 '21 at 05:04