
Splitter

Splitting data is very useful when preparing data for ML; splitting it into test and train sets is the most common case. Splitter is another transformer: it takes an input sf and returns the train and test sets in a dictionary under the test and train keys. A Splitter generally works on the default sf and divides it into test and train sets. To split several sfs at once, for example the children of a GroupSFrame, use the MultiSplitter, which is discussed later.

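Every Splitter has a percent argument that sets how much of the input sf goes to the train set; the remaining rows form the test set. For example, with 10 rows and percent set to 0.8, the train set has 8 rows and the test set has 2 rows. This is consistent with the Random example below, where percent=0.8 on four rows yields three train rows and one test row.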

All Splitter classes also have a clear_input argument. If it is true, the output contains only the test and train sets, and every other child of the input is removed. If it is false, all children of the input are returned together with test and train in one dictionary.
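A minimal sketch of the clear_input behavior described above, using plain Python dicts of lists in place of sfs. The function name assemble_output and the child keys are hypothetical illustrations, not part of seshat's API.

```python
# Hypothetical sketch of how clear_input shapes a splitter's output dict.
# Plain dicts of lists stand in for sfs; this is not seshat's code.

def assemble_output(children, test, train, clear_input):
    """Combine the split results with the input's other children."""
    if clear_input:
        # Only the two split sets survive.
        return {"test": test, "train": train}
    # Keep every input child and add the two split sets next to them.
    output = dict(children)
    output["test"] = test
    output["train"] = train
    return output


children = {"default": [1, 2, 3, 4, 5], "labels": ["a", "b"]}
test, train = [5], [1, 2, 3, 4]

print(sorted(assemble_output(children, test, train, clear_input=True)))
# ['test', 'train']
print(sorted(assemble_output(children, test, train, clear_input=False)))
# ['default', 'labels', 'test', 'train']
```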

Output Dictionary Keys

As mentioned before, the result of a splitter is a dictionary. The group_keys argument sets its keys: default names the input sf to split, and test and train name the entries that hold the two sets in the output dictionary. group_keys has this format:

{
    "default": "default_sf_key",
    "test": "key_of_test_sf_in_output_dict",
    "train": "key_of_train_sf_in_output_dict",
}
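The sketch below shows one way the three group_keys entries could be used: default picks the child to split, while test and train name the output entries. The function split_with_keys, the ordered (non-random) cut, and the reading of percent as the train share are assumptions for illustration only.

```python
# Hypothetical sketch: how the three group_keys entries could be used.
# Not seshat's implementation; lists stand in for sfs and the split is
# a simple ordered cut so that only the key handling is on display.

def split_with_keys(children, group_keys, percent):
    """Split children[group_keys["default"]] and store the halves under
    the names given by group_keys["test"] and group_keys["train"]."""
    rows = children[group_keys["default"]]  # which sf to split
    n_train = int(len(rows) * percent)      # assumed: percent = train share
    return {
        group_keys["train"]: rows[:n_train],
        group_keys["test"]: rows[n_train:],
    }


group_keys = {"default": "prices", "test": "prices_test", "train": "prices_train"}
children = {"prices": [10, 11, 12, 13, 14]}

print(split_with_keys(children, group_keys, percent=0.8))
# {'prices_train': [10, 11, 12, 13], 'prices_test': [14]}
```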

Random

The RandomSplitter randomly separates the input data into test and train sets. Like any other splitter, it accepts percent. You can also set seed to make the results reproducible; its default value is 42.

from seshat.data_class import DFrame
from seshat.transformer.splitter.random import RandomSplitter

sf = DFrame.from_raw({"A": ["foo", "bar", "baz", "qux"]})

splitter = RandomSplitter(percent=0.8)
result = splitter(sf)
print(isinstance(result, dict))
# True

print(result["test"].data)
#     A
# 1  bar

print(result["train"].data)
#     A
# 3  qux
# 0  foo
# 2  baz
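To make the seed and percent behavior concrete, here is a minimal stand-in written with Python's standard random module. It is not RandomSplitter's implementation, so its shuffle order will not reproduce the rows shown above. It assumes percent is the train share, which is consistent with that output (percent=0.8 on four rows gave three train rows and one test row); the name random_split is hypothetical.

```python
import random


def random_split(rows, percent, seed=42):
    """Shuffle row indices with a fixed seed, then cut them into train/test.

    Hypothetical stand-in for a random splitter; percent is assumed to be
    the share of rows that go to the train set.
    """
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # same seed -> same order
    n_train = int(len(rows) * percent)    # int(4 * 0.8) == 3
    train = [(i, rows[i]) for i in indices[:n_train]]
    test = [(i, rows[i]) for i in indices[n_train:]]
    return {"test": test, "train": train}


rows = ["foo", "bar", "baz", "qux"]
first = random_split(rows, percent=0.8)
again = random_split(rows, percent=0.8)

print(len(first["train"]), len(first["test"]))  # 3 1
print(first == again)                           # True
```

Calling it twice with the same seed returns identical splits, which is the reproducibility the seed argument is meant to provide.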

Block

The block number is important information in blockchain data because it increases monotonically over time. Seshat therefore has a built-in splitter for block numbers, BlockSplitter. This splitter finds the minimum block number and uses it together with percent to determine a cutoff value. All rows with block numbers above the cutoff are treated as test data, and the rest as train data. To use it, you must pass block_num_col, the name of the block number column in the input sf. The percent argument can also be set in the constructor; the example below relies on its default value.
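The text above does not give the exact cutoff formula. The sketch below assumes a linear one, cutoff = min + (max - min) * percent, with rows strictly above the cutoff going to test. Both the formula (including the use of the maximum block number) and the helper name block_split are assumptions, so its results are not guaranteed to match BlockSplitter's.

```python
def block_split(block_numbers, percent):
    """Hypothetical block-number split: rows above the cutoff become test."""
    low, high = min(block_numbers), max(block_numbers)
    # Assumed formula: move `percent` of the way from the lowest to the
    # highest block number.
    cutoff = low + (high - low) * percent
    test = [(i, b) for i, b in enumerate(block_numbers) if b > cutoff]
    train = [(i, b) for i, b in enumerate(block_numbers) if b <= cutoff]
    return {"cutoff": cutoff, "test": test, "train": train}


result = block_split([5, 100, 20, 60, 95, 40, 80, 0], percent=0.8)
print(result["cutoff"])  # 80.0
print(result["test"])    # [(1, 100), (4, 95)]
```

Because the split is by block value rather than by row count, the test set here holds two of eight rows even though percent is 0.8.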

from seshat.data_class import DFrame
from seshat.transformer.splitter.block import BlockSplitter

sf = DFrame.from_raw({"BLOCK_NUMBER": [1, 40, 50, 60, 70, 80, 90, 100, 2, 3]})

splitter = BlockSplitter(block_num_col="BLOCK_NUMBER")

result = splitter(sf)
print(result["test"].data)
#    BLOCK_NUMBER
# 5            80
# 6            90
# 7           100

print(result["train"].data)
#    BLOCK_NUMBER
# 0             1
# 1            40
# 2            50
# 3            60
# 4            70
# 8             2
# 9             3

This splitter also checks the type of the block number column. The column must be numeric, so if it is not, it is converted to a numeric type before the split.
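The conversion matters because strings compare character by character, so a string block-number column yields wrong minimums, maximums, and cutoffs. The snippet uses int() purely for illustration; it does not show how BlockSplitter actually performs the conversion.

```python
# Why a string block-number column must be converted before splitting:
# strings compare character by character, not by numeric value.
raw = ["1", "40", "100", "90"]

print(max(raw))                  # 90   ("9" > "1", so "90" > "100")
numeric = [int(v) for v in raw]  # illustrative conversion only
print(max(numeric))              # 100
```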

MultiSplitter

As mentioned at the beginning of this page, MultiSplitter is a special splitter for splitting more than one sf. It takes a list of splitters and passes the input sf to every one of these child splitters. MultiSplitter also forces clear_input to false on every child splitter.

The result of MultiSplitter is a dictionary with test and train keys, which, as with any splitter, come from the group_keys argument. Each child splitter returns its own dictionary of test and train sets. MultiSplitter collects all child test sets into one GroupSFrame under the test key and all child train sets into another GroupSFrame under the train key. Inside each group, a child's set is stored under that child's default key (sf1 and sf2 in the example below).
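A minimal sketch of this grouping step, with plain dicts standing in for GroupSFrame and a toy ordered split standing in for the child splitters. The names first_n_split and multi_split are hypothetical, the reading of percent as the train share is an assumption, and the sketch leaves out clear_input handling.

```python
# Hypothetical sketch of MultiSplitter-style merging, not seshat's code.
# Each "child splitter" is modelled as a (default_key, split_fn) pair.

def first_n_split(rows, percent=0.8):
    """Toy child splitter: ordered cut, train share = percent (assumed)."""
    n_train = int(len(rows) * percent)
    return {"test": rows[n_train:], "train": rows[:n_train]}


def multi_split(children, splitters):
    """Run each child splitter on its own sf and group the results."""
    merged = {"test": {}, "train": {}}
    for default_key, split_fn in splitters:
        parts = split_fn(children[default_key])
        # File each child's sets under that child's default key.
        merged["test"][default_key] = parts["test"]
        merged["train"][default_key] = parts["train"]
    return merged


children = {"sf1": [1, 2, 3, 4, 5], "sf2": ["foo", "bar", "baz", "qux"]}
result = multi_split(children, [("sf1", first_n_split), ("sf2", first_n_split)])

print(list(result["train"].keys()))  # ['sf1', 'sf2']
print(result["test"])                # {'sf1': [5], 'sf2': ['qux']}
```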

from seshat.data_class import DFrame, GroupSFrame
from seshat.transformer.splitter.base import MultiSplitter
from seshat.transformer.splitter.block import BlockSplitter
from seshat.transformer.splitter.random import RandomSplitter

sf = GroupSFrame(
    children={
        "sf1": DFrame.from_raw(
            {"BLOCK_NUMBER": [1, 40, 50, 60, 70, 80, 90, 100, 2, 3]}
        ),
        "sf2": DFrame.from_raw({"A": ["foo", "bar", "baz", "qux"]}),
    }
)

block_splitter = BlockSplitter(
    block_num_col="BLOCK_NUMBER",
    group_keys={"default": "sf1", "test": "test", "train": "train"},
)
random_splitter = RandomSplitter(
    group_keys={"default": "sf2", "test": "test", "train": "train"},
)

splitter = MultiSplitter(splitters=[block_splitter, random_splitter])
result = splitter(sf)

print(list(result["train"].keys()))
# ['sf1', 'sf2']
print(list(result["test"].keys()))
# ['sf1', 'sf2']