
Branch

"Branch" is a transformer that runs multiple pipelines within a single step and optionally merges their outputs. It can execute the pipelines in parallel or sequentially.

Why do we need branch?

"Branch" makes it possible to build complex, non-linear flows: instead of a single chain of transformers, different parts of the input can go through different pipelines.

Branch in depth

"Branch" takes a dictionary whose keys are sframe keys and whose values are either a single pipeline or a list of pipelines. For every key-value pair, "Branch" retrieves the sframe stored under that key in the input sframe. If the value is a single pipeline, the sframe is passed to that pipeline and the result is added to the result dictionary under the same key. If the value is a list of pipelines, the sframe is passed to each pipeline in the list and each result is added to the result dictionary under a key of the form key + "__" + index of that pipeline.
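
The dispatch rule above can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: pipelines are stood in for by ordinary callables, the grouped sframe by a plain dict, and branch_grouped is a hypothetical helper name. The sketch also assumes that list indices start at 0, which the description above does not specify.

```python
from typing import Callable, Dict, List, Union

# Stand-ins for illustration only: a "pipeline" is any callable and a
# grouped "sframe" is a plain dict of key -> data.
PipelineFn = Callable[[list], list]
PipeMap = Dict[str, Union[PipelineFn, List[PipelineFn]]]


def branch_grouped(frames: Dict[str, list], pipe_map: PipeMap) -> Dict[str, list]:
    """Mimic the key-based dispatch described above (hypothetical helper)."""
    result: Dict[str, list] = {}
    for key, value in pipe_map.items():
        frame = frames[key]  # retrieve the child stored under this key
        if isinstance(value, list):
            # One result per pipeline, stored as key + "__" + index.
            # Assumption: indices start at 0.
            for index, pipeline in enumerate(value):
                result[f"{key}__{index}"] = pipeline(frame)
        else:
            # A single pipeline keeps the original key.
            result[key] = value(frame)
    return result


frames = {"default": [1, 2, 3], "address": [10, 20]}
pipe_map: PipeMap = {
    "default": lambda rows: [r * 2 for r in rows],
    "address": [lambda rows: rows[:1], lambda rows: [r + 1 for r in rows]],
}
out = branch_grouped(frames, pipe_map)
print(out)
# {'default': [2, 4, 6], 'address__0': [10], 'address__1': [11, 21]}
```

The "address" entry maps to a list of two pipelines, so it produces two result keys, while "default" keeps its own key.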

Note that if the input sframe is not grouped, there are no children to look up by key, so every pipeline receives the same input sframe. The results are still stored under the keys described above.
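
For the ungrouped case, a minimal sketch could look like the following. As before, this is an assumption-laden illustration rather than the library's code: branch_ungrouped is a hypothetical helper, built-in functions stand in for pipelines, a plain list stands in for the sframe, and indices are assumed to start at 0. Whether the real "Branch" copies the input before handing it to each pipeline is not stated in this page, so the sketch simply passes the same object.

```python
from typing import Callable, Dict, List, Union

PipelineFn = Callable[[list], object]


def branch_ungrouped(
    frame: list, pipe_map: Dict[str, Union[PipelineFn, List[PipelineFn]]]
) -> Dict[str, object]:
    """Every pipeline gets the same input; only the result keys differ."""
    result: Dict[str, object] = {}
    for key, value in pipe_map.items():
        if isinstance(value, list):
            # Assumption: indices start at 0.
            for index, pipeline in enumerate(value):
                result[f"{key}__{index}"] = pipeline(frame)
        else:
            result[key] = value(frame)
    return result


rows = [3, 1, 2]
ungrouped_out = branch_ungrouped(rows, {"sorted": sorted, "stats": [min, max]})
print(ungrouped_out)
# {'sorted': [1, 2, 3], 'stats__0': 1, 'stats__1': 3}
```

Here the keys "sorted" and "stats" do not select anything from the input; they only name the entries in the result.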

Parallel or Sequential Execution

"Branch" can run its pipelines in parallel or sequentially. In parallel mode, "Branch" uses Python multiprocessing. The maximum number of worker processes is set with max_workers in the "Branch" constructor.

Branch(parallel=True, max_workers=16)

The default value of max_workers is 16.
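
To make the difference between the two modes concrete, here is a rough sketch built on the standard library's concurrent.futures.ProcessPoolExecutor. It is not how "Branch" is implemented internally (this page only says that it uses multiprocessing); run_tasks, double, drop_first, and TASKS are hypothetical names chosen for the example.

```python
from concurrent.futures import ProcessPoolExecutor
from typing import Callable, Dict, List, Tuple

# (result key, pipeline stand-in, input rows)
Task = Tuple[str, Callable[[list], list], list]


def double(rows: list) -> list:
    return [r * 2 for r in rows]


def drop_first(rows: list) -> list:
    return rows[1:]


def run_tasks(tasks: List[Task], parallel: bool = False, max_workers: int = 16) -> Dict[str, list]:
    """Run each task one after another, or in a pool of worker processes."""
    if not parallel:
        return {key: pipeline(rows) for key, pipeline, rows in tasks}
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        # Submit everything first so the tasks can overlap, then collect.
        futures = {key: pool.submit(pipeline, rows) for key, pipeline, rows in tasks}
        return {key: future.result() for key, future in futures.items()}


TASKS: List[Task] = [("default", double, [1, 2]), ("address", drop_first, [3, 4])]

if __name__ == "__main__":
    # The guard is needed because worker processes may re-import this module.
    print(run_tasks(TASKS))
    print(run_tasks(TASKS, parallel=True, max_workers=2))
```

Both calls produce the same dictionary; only the scheduling differs. Functions sent to worker processes must be picklable, which is why the stand-in pipelines are defined at module level rather than as lambdas.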

How to configure

"Branch" is configured with the pipe_map dictionary. Each key is a string that names an sframe in the input sframe, and each value is a pipeline or a list of pipelines. For every entry in pipe_map, the sframe with that key is taken from the input and passed to the pipeline (or pipelines) given as the value.

from seshat.data_class import SFrame, DFrame, GroupSFrame
from seshat.transformer.pipeline import Pipeline
from seshat.transformer.pipeline.branch import Branch


class Transformer1:
    def __call__(self, sf: SFrame) -> SFrame:
        # Normalize data in sf
        return sf

class Transformer2:
    def __call__(self, sf: SFrame) -> SFrame:
        # Filter rows in sf
        return sf

class Transformer3:
    def __call__(self, sf: SFrame) -> SFrame:
        # Preprocessing
        return sf

class Transformer4:
    def __call__(self, sf: SFrame) -> SFrame:
        # Preprocessing
        return sf


# Two child frames, grouped under the keys "default" and "address"
data_1 = {"A": ["foo", "bar"], "B": [1, 2]}
data_2 = {"A": ["baz", "qux"], "B": [3, 4]}
sf_1 = DFrame.from_raw(data_1)
sf_2 = DFrame.from_raw(data_2)
sf = GroupSFrame(children={"default": sf_1, "address": sf_2})

# One pipeline per child
default_pipeline = Pipeline(pipes=[Transformer1(), Transformer2()])
address_pipeline = Pipeline(pipes=[Transformer3(), Transformer4()])

# Each pipe_map key selects the child with the same key in sf
branch = Branch(pipe_map={"default": default_pipeline, "address": address_pipeline})
result_sf = branch(sf)
print(list(result_sf.keys))