You can also use PyArrow's grouped aggregations to deduplicate a table (see https://arrow.apache.org/docs/python/compute.html#grouped-aggregations):
import numpy as np
import pyarrow as pa
from typing import Literal
def deduplicate(
    table: pa.Table,
    keys: str | list[str],
    op: Literal["one", "first", "last"] = "one",
) -> pa.Table:
    """Return *table* with a single row kept per distinct value of *keys*.

    A temporary row-number column is appended, the table is grouped by
    *keys*, and the row number selected by the ``op`` aggregation in each
    group is used to pick the surviving row from the original table.

    Args:
        table: The table to deduplicate.
        keys: Column name, or list of column names, that identify duplicates.
        op: Name of the grouped aggregation applied to the row numbers:
            ``"one"``, ``"first"`` or ``"last"``.

    Returns:
        A table with the same columns as *table* and one row per group.
        Rows appear in the order the grouped aggregation emits the groups.
    """
    # Pick a helper column name that cannot clash with an existing column.
    index_col = "__index__"
    while index_col in table.column_names:
        index_col += "_"

    # Fixed: the original referenced an undefined name ``dt`` here.
    indexed = table.append_column(index_col, pa.array(np.arange(len(table))))

    # "first" and "last" are ordered aggregations, so threading is only
    # enabled for "one".
    groups = indexed.group_by(keys, use_threads=(op == "one")).aggregate(
        [(index_col, op)]
    )

    # Aggregated columns are named "<column>_<function>".
    selected = indexed.take(groups[f"{index_col}_{op}"])
    return selected.drop_columns([index_col])