Branch
"Branch" is a transformer that runs several pipelines on its input and optionally merges their outputs. The pipelines can be executed either in parallel or sequentially.
Why do we need Branch?
"Branch" makes it possible to build complex flows in pipelines, where different parts of the data are handled by different pipelines instead of a single linear chain.
Branch in depth
"Branch" takes a dictionary whose keys are sframe keys and whose values are a pipeline or a list of pipelines. For every key-value pair in this config, "Branch" retrieves the sframe stored under that key in the input sframe. If the value is a single pipeline, the retrieved sframe is passed to it, and the output is added to the result dictionary under the same key. If the value is a list of pipelines, the retrieved sframe is passed to each of them, and each output is added to the result dictionary under a key of the form key + "__" + index, where index is the position of that pipeline in the list.
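The routing and naming rules above can be sketched in plain Python. This is a minimal sketch, not seshat's implementation: it assumes a grouped sframe can be modelled as a plain dict of children, that a pipeline is any callable, and that list indices start at 0. The helper name branch_sequential is hypothetical.

```python
from typing import Any, Callable, Dict, List, Union

# Hypothetical stand-ins: an sframe is any object and a pipeline is any callable.
PipelineFn = Callable[[Any], Any]
PipeMap = Dict[str, Union[PipelineFn, List[PipelineFn]]]


def branch_sequential(children: Dict[str, Any], pipe_map: PipeMap) -> Dict[str, Any]:
    """Send each child to the pipeline(s) under the same key and collect the results."""
    result: Dict[str, Any] = {}
    for key, value in pipe_map.items():
        child = children[key]  # the child sframe stored under this key
        if isinstance(value, list):
            # A list of pipelines: store each output as key + "__" + index (0-based assumed).
            for index, pipeline in enumerate(value):
                result[f"{key}__{index}"] = pipeline(child)
        else:
            # A single pipeline: the output keeps the original key.
            result[key] = value(child)
    return result


children = {"default": [1, 2], "address": [3, 4]}
pipe_map = {
    "default": lambda rows: [r * 10 for r in rows],
    "address": [sum, max],
}
print(branch_sequential(children, pipe_map))
# {'default': [10, 20], 'address__0': 7, 'address__1': 4}
```

The single pipeline under "default" keeps its key, while the two pipelines under "address" produce address__0 and address__1.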
Note that if the input sframe is not grouped, the results are still stored under keys built with the same template, but every pipeline receives the same input sframe instead of a child retrieved by key.
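The difference between grouped and non-grouped input comes down to which sframe each pipeline receives. A minimal sketch, assuming (hypothetically) that a grouped input is modelled as a dict and anything else counts as non-grouped; select_input is an illustrative name, not part of seshat.

```python
from typing import Any


def select_input(sf: Any, key: str) -> Any:
    """Return the sframe that the pipeline(s) registered under `key` receive."""
    if isinstance(sf, dict):
        # Grouped input (modelled here as a dict of children): pick the child at `key`.
        return sf[key]
    # Non-grouped input: every pipeline receives the same, whole input.
    return sf


rows = [3, 1, 2]  # a non-grouped input
pipe_map = {"default": sorted, "address": len}
result = {key: pipeline(select_input(rows, key)) for key, pipeline in pipe_map.items()}
print(result)  # {'default': [1, 2, 3], 'address': 3}
```

Both pipelines read the same list, yet their outputs are still stored under their own keys.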
Parallel or Sequential Execution
"Branch" can run its pipelines either in parallel or sequentially. For parallel execution, "Branch" uses Python multiprocessing. The maximum number of worker processes can be set with max_workers in the "Branch" constructor:
Branch(parallel=True, max_workers=16)
The default value of max_workers is 16.
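As a rough illustration of the two execution modes, the sketch below runs a set of independent jobs either one after another or on a pool of worker processes using concurrent.futures.ProcessPoolExecutor from the Python standard library. The function run_pipelines and its job format (output key mapped to a pipeline and its input) are hypothetical; this shows the general technique, not how seshat's Branch is implemented internally.

```python
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Callable, Dict, Tuple

# Hypothetical job format: output key -> (pipeline, input sframe)
Jobs = Dict[str, Tuple[Callable[[Any], Any], Any]]


def run_pipelines(jobs: Jobs, parallel: bool, max_workers: int = 16) -> Dict[str, Any]:
    """Run every job and return its output under the job's key."""
    if not parallel:
        # Sequential mode: run the jobs one after another in this process.
        return {key: fn(arg) for key, (fn, arg) in jobs.items()}
    # Parallel mode: submit every job to a pool of at most max_workers processes.
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        futures = {key: pool.submit(fn, arg) for key, (fn, arg) in jobs.items()}
        # Collect results in the original key order.
        return {key: future.result() for key, future in futures.items()}


if __name__ == "__main__":
    jobs = {"default": (sorted, [3, 1, 2]), "address": (sum, [3, 4])}
    print(run_pipelines(jobs, parallel=False))  # {'default': [1, 2, 3], 'address': 7}
```

Passing parallel=True with the same jobs should return the same dictionary. In that mode every pipeline and its input must be picklable (module-level functions rather than lambdas), and on platforms that start worker processes with spawn the call should sit under an if __name__ == "__main__": guard, as above.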
How to configure
"Branch" is configured with the pipe_map dictionary. Its keys are strings that name each branch, and its values are either a pipeline or a list of pipelines. Each key of pipe_map refers to the sframe stored under the same key in the input sframe, so for every key-value pair in pipe_map, the sframe with that key is separated from the input and passed to the pipeline (or pipelines) given as the value. In the example below, the "default" child goes through default_pipeline and the "address" child goes through address_pipeline.
from seshat.data_class import SFrame, DFrame, GroupSFrame
from seshat.transformer.pipeline import Pipeline
from seshat.transformer.pipeline.branch import Branch


class Transformer1:
    def __call__(self, sf: SFrame) -> SFrame:
        # Normalize data in sf
        return sf


class Transformer2:
    def __call__(self, sf: SFrame) -> SFrame:
        # Filter rows in sf
        return sf


class Transformer3:
    def __call__(self, sf: SFrame) -> SFrame:
        # Preprocessing
        return sf


class Transformer4:
    def __call__(self, sf: SFrame) -> SFrame:
        # Preprocessing
        return sf


# Two child sframes, grouped under the keys "default" and "address"
data_1 = {"A": ["foo", "bar"], "B": [1, 2]}
data_2 = {"A": ["baz", "qux"], "B": [3, 4]}
sf_1 = DFrame.from_raw(data_1)
sf_2 = DFrame.from_raw(data_2)
sf = GroupSFrame(children={"default": sf_1, "address": sf_2})

# Each child sframe is sent to the pipeline registered under the same key
default_pipeline = Pipeline(pipes=[Transformer1(), Transformer2()])
address_pipeline = Pipeline(pipes=[Transformer3(), Transformer4()])
branch = Branch(pipe_map={"default": default_pipeline, "address": address_pipeline})

result_sf = branch(sf)
print(list(result_sf.keys))