Source code for law.contrib.pyarrow.util

# coding: utf-8

"""
PyArrow related utilities.
"""

from __future__ import annotations

__all__ = ["merge_parquet_files", "merge_parquet_task"]

import os
import shutil
import pathlib
import collections

from law.task.base import Task
from law.target.file import FileSystemFileTarget, get_path
from law.target.local import LocalFileTarget, LocalDirectoryTarget
from law.util import map_verbose, human_bytes
from law._types import Sequence, Any, Callable


def merge_parquet_files(
    src_paths: Sequence[str | pathlib.Path | FileSystemFileTarget],
    dst_path: str | pathlib.Path | FileSystemFileTarget,
    force: bool = True,
    callback: Callable[[int], Any] | None = None,
    writer_opts: dict[str, Any] | None = None,
    copy_single: bool = False,
    skip_empty: bool = True,
    target_row_group_size: int = 0,
) -> str:
    """
    Merges parquet files in *src_paths* into a new file at *dst_path*. Intermediate directories are
    created automatically. When *dst_path* exists and *force* is *True*, the file is removed first.
    Otherwise, an exception is thrown.

    *callback* can refer to a callable accepting a single integer argument representing the index of
    the file after it was merged. It is invoked once per source file, in order, also for files that
    contain no rows or no row groups at all.

    *writer_opts* can be a dictionary of keyword arguments that are passed to the *ParquetWriter*
    instance. When *src_paths* contains only a single file, *copy_single* is *True* and
    *target_row_group_size* is not positive, the file is copied to *dst_path* and no merging takes
    place. Files (or, when merging row groups, row groups) containing no rows are skipped unless
    *skip_empty* is *False*.

    When *target_row_group_size* is a positive number, the merging is done on the level of
    particular row groups. These groups are merged in-memory such that each resulting group stored
    on disk, potentially except for the last one, will have *target_row_group_size* rows.

    The schema of the output file is taken from the first source file containing at least one row,
    or from the last source file when all of them are empty.

    The absolute, expanded *dst_path* is returned.

    :raises Exception: when *src_paths* is empty, or when *dst_path* exists while *force* is
        *False*.
    """
    if not src_paths:
        raise Exception("cannot merge empty list of parquet files")

    # default callable
    if not callable(callback):
        callback = lambda i: None

    # prepare paths
    abspath = lambda p: os.path.abspath(os.path.expandvars(os.path.expanduser(get_path(p))))
    src_paths = list(map(abspath, src_paths))
    dst_path = abspath(dst_path)

    # prepare the dst directory, tolerating concurrent creation by another process
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)

    # remove the file first when existing
    if os.path.exists(dst_path):
        if not force:
            raise Exception(f"destination path existing while force is False: {dst_path}")
        os.remove(dst_path)

    opts = dict(writer_opts or {})

    if target_row_group_size <= 0:
        if copy_single and len(src_paths) == 1:
            # trivial case, no need to rewrite the file
            shutil.copy(src_paths[0], dst_path)
            callback(0)
        else:
            _merge_parquet_tables(src_paths, dst_path, callback, opts, skip_empty)
    else:
        _merge_parquet_row_groups(
            src_paths,
            dst_path,
            callback,
            opts,
            skip_empty,
            target_row_group_size,
        )

    return dst_path


def _first_non_empty_schema(src_paths: Sequence[str]) -> Any:
    # Returns the schema of the first file with at least one row, falling back to the last file.
    import pyarrow.parquet as pq  # type: ignore[import-untyped, import-not-found]

    for path in src_paths[:-1]:
        # read_metadata avoids keeping an open ParquetFile handle around just for the row count
        if pq.read_metadata(path).num_rows > 0:
            return pq.read_schema(path)
    return pq.read_schema(src_paths[-1])


def _merge_parquet_tables(
    src_paths: Sequence[str],
    dst_path: str,
    callback: Callable[[int], Any],
    writer_opts: dict[str, Any],
    skip_empty: bool,
) -> None:
    # Appends each source file as a whole table, leaving the row grouping to the writer defaults.
    import pyarrow.parquet as pq  # type: ignore[import-untyped, import-not-found]

    schema = _first_non_empty_schema(src_paths)
    with pq.ParquetWriter(dst_path, schema, **writer_opts) as writer:
        for i, path in enumerate(src_paths):
            table = pq.read_table(path)
            if not skip_empty or table.num_rows > 0:
                writer.write_table(table)
            callback(i)


def _merge_parquet_row_groups(
    src_paths: Sequence[str],
    dst_path: str,
    callback: Callable[[int], Any],
    writer_opts: dict[str, Any],
    skip_empty: bool,
    target_row_group_size: int,
) -> None:
    # Streams row groups of all files in order and re-chunks them into groups of the target size.
    import pyarrow as pa  # type: ignore[import-untyped, import-not-found]
    import pyarrow.parquet as pq  # type: ignore[import-untyped, import-not-found]

    schema = _first_non_empty_schema(src_paths)

    # tables read but not yet written, and the total number of rows they hold
    pending: collections.deque[Any] = collections.deque()
    pending_rows = 0

    with pq.ParquetWriter(dst_path, schema, **writer_opts) as writer:
        for i, path in enumerate(src_paths):
            # open one source file at a time and always close it, even on errors
            f = pq.ParquetFile(path)
            try:
                for g in range(f.num_row_groups):
                    table = f.read_row_group(g)
                    if skip_empty and table.num_rows == 0:
                        continue
                    pending.append(table)
                    pending_rows += table.num_rows

                    # write full groups as soon as enough rows are buffered
                    while pending_rows >= target_row_group_size:
                        group = _pop_rows(pending, target_row_group_size)
                        # pass the size explicitly so that pyarrow does not split groups that
                        # exceed its default maximum row group size
                        writer.write_table(group, row_group_size=target_row_group_size)
                        pending_rows -= target_row_group_size
            finally:
                f.close()

            # invoked for every file, including files without any row group
            callback(i)

        # write remaining rows as a final, potentially smaller group
        if pending:
            writer.write_table(
                pa.concat_tables(list(pending)),
                row_group_size=target_row_group_size,
            )


def _pop_rows(pending: collections.deque[Any], num_rows: int) -> Any:
    # Removes exactly num_rows rows from the front of pending (which must hold at least that many)
    # and returns them concatenated; a partially consumed table is put back as its remainder.
    import pyarrow as pa  # type: ignore[import-untyped, import-not-found]

    chunks = []
    missing = num_rows
    while missing > 0:
        table = pending.popleft()
        if table.num_rows <= missing:
            chunks.append(table)
            missing -= table.num_rows
        else:
            chunks.append(table.slice(0, missing))
            pending.appendleft(table.slice(missing))
            missing = 0
    return pa.concat_tables(chunks)
def merge_parquet_task(
    task: Task,
    inputs: Sequence[str | pathlib.Path | FileSystemFileTarget],
    output: str | pathlib.Path | FileSystemFileTarget,
    local: bool = False,
    cwd: str | pathlib.Path | LocalDirectoryTarget | None = None,
    force: bool = True,
    **kwargs: Any,
) -> None:
    """
    This method is intended to be used by tasks that are supposed to merge parquet files, e.g. when
    inheriting from :py:class:`law.contrib.tasks.MergeCascade`. *inputs* should be a sequence of
    targets (or local paths) that represent the files to merge into *output*.

    When *local* is *False* and files need to be copied from remote first, *cwd* can be set as the
    download directory, either as a path or as a :py:class:`LocalDirectoryTarget`. When empty, a
    temporary directory is used.

    The *task* itself is used to print and publish messages via its
    :py:meth:`law.Task.publish_message` and :py:meth:`law.Task.publish_step` methods.

    When *force* is *True*, any existing output file is overwritten. Otherwise, *force* is forwarded
    to :py:func:`merge_parquet_files`, which raises an exception when the file it should write to
    already exists. All additional *kwargs* are forwarded to :py:func:`merge_parquet_files` which
    is used internally for the actual merging.

    :raises Exception: when the output was not created during merging.
    """
    abspath = lambda p: os.path.abspath(os.path.expandvars(os.path.expanduser(get_path(p))))

    # ensure inputs are targets
    inputs = [
        inp if isinstance(inp, FileSystemFileTarget) else LocalFileTarget(abspath(inp))
        for inp in inputs
    ]

    # ensure output is a target
    if not isinstance(output, FileSystemFileTarget):
        output = LocalFileTarget(abspath(output))

    def merge(inputs: Sequence[FileSystemFileTarget], output: FileSystemFileTarget) -> None:
        with task.publish_step(f"merging {len(inputs)} parquet files ...", runtime=True):
            # clear the output if necessary
            if output.exists() and force:
                output.remove()

            # merge, forwarding force so that force=False actually prevents overwriting
            merge_parquet_files(
                [inp.abspath for inp in inputs],
                output.abspath,
                force=force,
                **kwargs,
            )

            stat = output.exists(stat=True)
            if not stat:
                raise Exception(f"output '{output.abspath}' not created during merging")

            # print the size
            output_size = human_bytes(stat.st_size, fmt=True)
            task.publish_message(f"merged file size: {output_size}")

    if local:
        # everything is local, just merge
        merge(inputs, output)
        return

    # when not local, we need to fetch files first into the cwd; plain paths (str or pathlib.Path)
    # are converted to directory targets, anything else falls back to a temporary directory
    if isinstance(cwd, (str, pathlib.Path)):
        cwd = LocalDirectoryTarget(abspath(cwd))
    elif not isinstance(cwd, LocalDirectoryTarget):
        cwd = LocalDirectoryTarget(is_tmp=True)
    cwd.touch()

    # fetch
    with task.publish_step("fetching inputs ...", runtime=True):
        def fetch(inp: FileSystemFileTarget) -> LocalFileTarget:
            local_inp: LocalFileTarget = cwd.child(inp.unique_basename, type="f")  # type: ignore[assignment] # noqa
            inp.copy_to_local(local_inp, cache=False)
            return local_inp

        def callback(i: int) -> None:
            task.publish_message(f"fetch file {i + 1} / {len(inputs)}")

        local_inputs = map_verbose(fetch, inputs, every=5, callback=callback)  # type: ignore[arg-type] # noqa

    # merge into a localized output
    with output.localize("w", cache=False) as local_output:
        merge(local_inputs, local_output)