diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index 766bc3edcff3..89c14cb87993 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -878,7 +878,6 @@ def select_tiles_from_rotated_row_vregs( end_src_col: int, first_dst_tile_sublane_offset: int, dst_layout: VectorLayout, - hw_generation: int, ) -> ValueLike: """Assembles a destination tile using partial data from rotated vregs using a divide-and-conquer strategy. @@ -892,7 +891,6 @@ def select_tiles_from_rotated_row_vregs( first_dst_tile_sublane_offset: Sublane offset where the first dst tile to be selected starts. dst_layout: Destination layout, based on which retiling is being performed. - hw_generation: The generation of a target hardware. Returns: A new vreg assembled from dst tiles stored in given rotated vregs. @@ -911,7 +909,6 @@ def select_tiles_from_rotated_row_vregs( mid_src_col, first_dst_tile_sublane_offset, dst_layout, - hw_generation, ) left_tiles_count = mid_src_col - start_src_col + 1 @@ -926,7 +923,6 @@ def select_tiles_from_rotated_row_vregs( end_src_col, right_first_dst_tile_sublane_offset, dst_layout, - hw_generation, ) i1 = ir.IntegerType.get_signless(1) @@ -983,7 +979,6 @@ def retile_to_reduced_sublanes( src_layout: VectorLayout, src_vreg_array: np.ndarray, dst_layout: VectorLayout, - hw_generation: int, ) -> np.ndarray: """Retiles across vregs to match the destination layout when the sublane tiling dimension is reduced. @@ -993,7 +988,6 @@ def retile_to_reduced_sublanes( src_vreg_array: An array of vregs storing source tiles. dst_layout: The destination layout, with reduced sublane dimension, based on which the retiling will be performed. - hw_generation: The generation of a target hardware. Returns: A new array of vregs that store tiles based on the destination layout. @@ -1121,7 +1115,6 @@ def retile_to_reduced_sublanes( end_src_col=src_vreg_array_col_end, first_dst_tile_sublane_offset=first_dst_tile_sublane_offset, dst_layout=dst_layout, - hw_generation=hw_generation, ) if first_dst_tile_sublane_offset == 0: # No need to rotate. First dst tile is already at offset 0, which means @@ -1190,7 +1183,7 @@ def is_supported_reduced_sublanes_retile( # TODO(apaszke): Test this function properly def relayout( - v: ir.Value, src: VectorLayout, dst: VectorLayout, hw_generation: int + v: ir.Value, src: VectorLayout, dst: VectorLayout ) -> ValueLike: """Changes the layout of a vector value. @@ -1198,7 +1191,6 @@ def relayout( v: The value to relayout. src: The current layout of v. dst: The target layout of v. - hw_generation: The generation of a target hardware. Returns: A new MLIR vector value, laid out as requested by dst. @@ -1354,7 +1346,6 @@ def relayout( src_layout=src, src_vreg_array=src_tiles, dst_layout=dst, - hw_generation=hw_generation, ) src = dst @@ -1597,7 +1588,7 @@ def apply_layout_op(ctx: RewriteContext, op: ir.OpView): continue with ir.InsertionPoint(op), op.location: new_v = relayout( - v, src=lo, dst=li, hw_generation=ctx.hardware_generation + v, src=lo, dst=li ).result ctx.set_operand(op, idx, new_v) else: