Splitter
Splitting data is a common step when preparing data for machine learning, and dividing it into train and test sets is the typical case. Splitter is a transformer that takes an input sf and returns the test and train sets in a dictionary with the keys test and train.
A Splitter works on the default sf of its input and divides it into test and train sets. To split several sfs at once, use the MultiSplitter, which is described later on this page.
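Every Splitter has a percent argument that sets the proportion used to divide the input sf. In the examples on this page, percent behaves as the share of rows kept for the train set: in the RandomSplitter example, percent=0.8 on 4 rows produces 3 train rows and 1 test row. By the same rule, 10 rows with percent set to 0.8 would give 8 train rows and 2 test rows.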
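Whichever set percent sizes, the row counts follow from simple arithmetic. The sketch below is plain Python, not seshat code: split_sizes is a hypothetical helper that assumes the size of the set the fraction refers to is rounded down and the other set receives the remaining rows. Splitter's actual rounding rule is not stated on this page.

```python
import math


def split_sizes(n_rows: int, fraction: float) -> tuple[int, int]:
    """Return (rows in the set `fraction` refers to, rows in the other set).

    Hypothetical helper: assumes the first size is rounded down and the
    other set takes the remainder. Splitter's real rule may differ.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    selected = math.floor(n_rows * fraction)
    return selected, n_rows - selected


print(split_sizes(10, 0.8))  # (8, 2)
print(split_sizes(4, 0.8))   # (3, 1)
```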
All Splitter classes also accept a clear_input argument. If it is True, the output contains only test and train; every other child is removed. If it is False, the output dictionary contains all of the input's children plus test and train.
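The effect of clear_input on the output dictionary can be modeled with plain dictionaries. This is a minimal sketch, not seshat's implementation: assemble_output is a hypothetical function, and strings stand in for real sf objects.

```python
from typing import Any


def assemble_output(
    children: dict[str, Any], test: Any, train: Any, clear_input: bool
) -> dict[str, Any]:
    """Hypothetical model of how clear_input shapes a splitter's output."""
    # Keep a copy of the input's children unless clear_input drops them.
    output: dict[str, Any] = {} if clear_input else dict(children)
    output["test"] = test
    output["train"] = train
    return output


children = {"default": "raw_sf", "labels": "labels_sf"}
print(assemble_output(children, "test_sf", "train_sf", clear_input=True))
# {'test': 'test_sf', 'train': 'train_sf'}
print(assemble_output(children, "test_sf", "train_sf", clear_input=False))
# {'default': 'raw_sf', 'labels': 'labels_sf', 'test': 'test_sf', 'train': 'train_sf'}
```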
Output Dictionary Keys
As mentioned above, a splitter returns a dictionary. You can choose the keys of this dictionary with the group_keys argument, which has this format:
{
    "default": "default_sf_key",
    "test": "key_of_test_sf_in_output_dict",
    "train": "key_of_train_sf_in_output_dict",
}
Here default is the key of the sf to split, and test and train are the keys under which the two sets appear in the output dictionary.
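How group_keys routes data can be sketched with plain Python. The function apply_group_keys and the toy split function last_row_as_test are hypothetical illustrations, not part of seshat; lists stand in for sf objects.

```python
from typing import Callable


def apply_group_keys(
    children: dict[str, list],
    group_keys: dict[str, str],
    split: Callable[[list], tuple[list, list]],
) -> dict[str, list]:
    """Split the child named by group_keys["default"] and store the two
    halves under the names given by group_keys["test"] and ["train"]."""
    test, train = split(children[group_keys["default"]])
    return {group_keys["test"]: test, group_keys["train"]: train}


def last_row_as_test(rows: list) -> tuple[list, list]:
    # Toy split: the last row is test data, everything else is train data.
    return rows[-1:], rows[:-1]


children = {"prices": [1, 2, 3, 4]}
keys = {"default": "prices", "test": "prices_test", "train": "prices_train"}
print(apply_group_keys(children, keys, last_row_as_test))
# {'prices_test': [4], 'prices_train': [1, 2, 3]}
```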
Random
RandomSplitter randomly divides the input data into test and train sets. Like any other splitter, it accepts percent. You can also set seed to make the results reproducible; the default seed is 42.
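The reproducibility that seed provides comes from using a seeded random number generator. The sketch below shows the idea with Python's random.Random; random_split and its test_fraction parameter are hypothetical and are not RandomSplitter's internals or its percent argument. The default seed of 42 simply mirrors the documented default.

```python
import random


def random_split(rows: list, test_fraction: float, seed: int = 42) -> dict[str, list]:
    """Shuffle row indices with a seeded RNG, then cut them into test/train.

    Hypothetical sketch: test_fraction is not RandomSplitter's percent.
    """
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation
    n_test = round(len(rows) * test_fraction)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return {
        "test": [rows[i] for i in test_idx],
        "train": [rows[i] for i in train_idx],
    }


rows = ["foo", "bar", "baz", "qux"]
first = random_split(rows, test_fraction=0.25)
second = random_split(rows, test_fraction=0.25)
print(first == second)  # True
```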
from seshat.data_class import DFrame
from seshat.transformer.splitter.random import RandomSplitter
sf = DFrame.from_raw({"A": ["foo", "bar", "baz", "qux"]})
splitter = RandomSplitter(percent=0.8)
result = splitter(sf)
isinstance(result, dict)  # True
result["test"].data
#      A
# 1  bar
result["train"].data
#      A
# 3  qux
# 0  foo
# 2  baz
Block
The block number is important information in blockchain data because it increases monotonically over time. For this reason, there is a built-in splitter for block numbers called BlockSplitter. It finds the minimum block number and uses percent to determine a cutoff value. Rows with block numbers above the cutoff become test data, and the rest become train data.
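A cutoff-based split can be sketched in a few lines. The exact formula BlockSplitter uses is not given on this page, so block_split below assumes a hypothetical rule that interpolates linearly between the minimum and maximum block numbers; it illustrates the technique and is not guaranteed to reproduce BlockSplitter's output.

```python
def block_split(blocks: list[int], fraction: float) -> dict[str, list[int]]:
    """Hypothetical cutoff rule: min + fraction * (max - min).

    Blocks strictly above the cutoff are test data; the rest are train
    data. Input order is preserved within each set.
    """
    low, high = min(blocks), max(blocks)
    cutoff = low + fraction * (high - low)
    return {
        "test": [b for b in blocks if b > cutoff],
        "train": [b for b in blocks if b <= cutoff],
    }


print(block_split([10, 20, 30, 40, 50], fraction=0.5))
# {'test': [40, 50], 'train': [10, 20, 30]}
```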
To use it, you must pass block_num_col, the name of the block number column in the input sf. percent can also be passed as a constructor argument.
from seshat.data_class import DFrame
from seshat.transformer.splitter.block import BlockSplitter
sf = DFrame.from_raw({"BLOCK_NUMBER": [1, 40, 50, 60, 70, 80, 90, 100, 2, 3]})
splitter = BlockSplitter(block_num_col="BLOCK_NUMBER")
result = splitter(sf)
result["test"].data
#    BLOCK_NUMBER
# 5            80
# 6            90
# 7           100
result["train"].data
#    BLOCK_NUMBER
# 0             1
# 1            40
# 2            50
# 3            60
# 4            70
# 8             2
# 9             3
This splitter also checks the type of the block number column. The column must be numeric, so a non-numeric column is converted to a numeric type.
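The kind of conversion involved can be sketched for individual values. The conversion BlockSplitter actually applies is not specified on this page; to_block_number is a hypothetical helper that accepts integers, digit strings, and whole-number floats or strings, and rejects anything else.

```python
def to_block_number(value: object) -> int:
    """Hypothetical coercion of one block number value to int."""
    if isinstance(value, bool):
        raise TypeError("booleans are not valid block numbers")
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)  # exact for digit strings such as "80"
        except ValueError:
            pass  # fall through for strings such as "80.0"
    number = float(value)  # raises for values that are not numeric
    if not number.is_integer():
        raise ValueError(f"not a whole block number: {value!r}")
    return int(number)


print([to_block_number(v) for v in ["1", "40", 50, 60.0, "70.0"]])
# [1, 40, 50, 60, 70]
```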
MultiSplitter
As mentioned at the beginning of this page, MultiSplitter is a special splitter for splitting more than one sf. It takes a list of splitters and passes the input sf to each of them. MultiSplitter ensures that clear_input is False for every child splitter.
The result of MultiSplitter is a dictionary with test and train keys; as described above, these keys come from the group_keys argument. Each child splitter produces its own dictionary of test and train sets. MultiSplitter adds every test set to a group sframe stored under the test key, and every train set to a group sframe stored under the train key. In the example below, each group sframe is keyed by the default key of the child splitter that produced it (sf1 and sf2).
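The regrouping step can be modeled with plain dictionaries standing in for group sframes. merge_split_results is a hypothetical function, not MultiSplitter's code; it assumes each child's result is keyed by that child's default key, and the sample data reuses the split results shown earlier on this page.

```python
from typing import Any


def merge_split_results(
    results: dict[str, dict[str, Any]],
    test_key: str = "test",
    train_key: str = "train",
) -> dict[str, dict[str, Any]]:
    """Regroup per-child {"test": ..., "train": ...} results into one test
    group and one train group, each keyed by the child's default key."""
    merged: dict[str, dict[str, Any]] = {test_key: {}, train_key: {}}
    for child_key, result in results.items():
        merged[test_key][child_key] = result["test"]
        merged[train_key][child_key] = result["train"]
    return merged


per_child = {
    "sf1": {"test": [80, 90, 100], "train": [1, 40, 50, 60, 70, 2, 3]},
    "sf2": {"test": ["bar"], "train": ["qux", "foo", "baz"]},
}
merged = merge_split_results(per_child)
print(list(merged["train"].keys()))  # ['sf1', 'sf2']
print(list(merged["test"].keys()))   # ['sf1', 'sf2']
```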
from seshat.data_class import DFrame, GroupSFrame
from seshat.transformer.splitter.base import MultiSplitter
from seshat.transformer.splitter.block import BlockSplitter
from seshat.transformer.splitter.random import RandomSplitter
sf = GroupSFrame(
    children={
        "sf1": DFrame.from_raw(
            {"BLOCK_NUMBER": [1, 40, 50, 60, 70, 80, 90, 100, 2, 3]}
        ),
        "sf2": DFrame.from_raw({"A": ["foo", "bar", "baz", "qux"]}),
    }
)
block_splitter = BlockSplitter(
    block_num_col="BLOCK_NUMBER",
    group_keys={"default": "sf1", "test": "test", "train": "train"},
)
random_splitter = RandomSplitter(
    group_keys={"default": "sf2", "test": "test", "train": "train"},
)
splitter = MultiSplitter(splitters=[block_splitter, random_splitter])
result = splitter(sf)
list(result["train"].keys())  # ['sf1', 'sf2']
list(result["test"].keys())  # ['sf1', 'sf2']