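We can use the first cell of each row and of each column as a marker recording whether that row or column must be zeroed. Because the first row and first column are themselves reused as markers, we first record separately whether they originally contained a zero. A second pass then zeroes every interior cell whose row or column is marked, and finally the first row and column are zeroed if their recorded flags are set.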

class Solution:
    def setZeroes(self, matrix: List[List[int]]) -> None:
        """Zero, in place, every row and column of *matrix* that contains a 0.

        The first row and first column are reused as marker storage, so the
        only extra state is two boolean flags.

        Args:
            matrix: A grid of ints, modified in place. Row lengths are taken
                from ``matrix[0]``, so the grid is assumed to be rectangular.
                An empty grid (``[]``) or a grid of empty rows (``[[]]``) is
                left unchanged.

        Returns:
            None. The result is written into *matrix*.
        """
        # With no rows or no columns there is nothing to zero, and the
        # indexing below (matrix[0], row[0]) would raise IndexError.
        if not matrix or not matrix[0]:
            return

        m, n = len(matrix), len(matrix[0])

        # The first row/column are about to be overwritten with markers, so
        # record up front whether they originally held a zero themselves.
        first_row_zero = 0 in matrix[0]
        first_col_zero = any(row[0] == 0 for row in matrix)

        # Mark: a zero at (i, j) in the interior flags row i via matrix[i][0]
        # and column j via matrix[0][j].
        for i in range(1, m):
            for j in range(1, n):
                if matrix[i][j] == 0:
                    matrix[i][0] = 0
                    matrix[0][j] = 0

        # Zero the interior from the markers. The first row/column are handled
        # last so the markers stay intact while they are still being read.
        for i in range(1, m):
            for j in range(1, n):
                if matrix[i][0] == 0 or matrix[0][j] == 0:
                    matrix[i][j] = 0

        # Finally apply the flags recorded before marking began.
        if first_row_zero:
            for j in range(n):
                matrix[0][j] = 0
        if first_col_zero:
            for row in matrix:
                row[0] = 0