# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
from ..events.events_find import _events_find_label
from ..misc import listify
def epochs_create(
    data,
    events=None,
    sampling_rate=1000,
    epochs_start=0,
    epochs_end=1,
    event_labels=None,
    event_conditions=None,
    baseline_correction=False,
):
    """Cut a signal (or a set of signals) into epochs around events.

    Parameters
    ----------
    data : DataFrame
        A DataFrame containing the different signal(s) as different columns.
        If a vector of values is passed (list, array or Series), it will be
        transformed into a DataFrame with a single 'Signal' column. If a tuple
        is passed, only its first element is used.
    events : list or ndarray or dict or int
        Events onset location (in samples). If a dict is passed (e.g., from
        ``events_find()``), its 'onset' and 'label' entries are used, as well
        as its 'condition' entry when present. If an integer is passed, will
        use this number to create an evenly spaced list of events. If None,
        will chunk the signal into successive blocks of the set duration.
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).
    epochs_start : int or float or list
        Epochs start relative to events_onsets (in seconds). The start can be
        negative to start epochs before a given event (to have a baseline for
        instance). Presumably a list gives one value per event (values are
        expanded by ``listify``) -- TODO confirm.
    epochs_end : int or float or list
        Epochs end relative to events_onsets (in seconds).
    event_labels : list
        A list containing unique event identifiers. If `None`, will use the
        event index number. Ignored when ``events`` is a dict.
    event_conditions : list
        An optional list containing, for each event, for example the trial
        category, group or experimental conditions.
    baseline_correction : bool
        If True, subtract from each numeric signal column its mean over the
        baseline, i.e., the samples from the epoch start up to time 0 (or only
        the first sample when the epoch starts after the event). The 'Index'
        column is left untouched. Defaults to False.

    Returns
    -------
    dict
        A dict containing one DataFrame per epoch, keyed by event label. Each
        DataFrame is indexed by time (in seconds, relative to the event onset)
        and contains the signal columns, an 'Index' column holding the
        position of each sample in the original data, a 'Label' column and,
        when a condition is available, a 'Condition' column. Samples falling
        outside of the original data are NaN.

    See Also
    --------
    events_find, events_plot, epochs_to_df, epochs_plot

    Examples
    --------
    >>> import neurokit2 as nk
    >>>
    >>> # Get data
    >>> data = nk.data("bio_eventrelated_100hz")
    >>>
    >>> # Find events
    >>> events = nk.events_find(data["Photosensor"],
    ...                         threshold_keep='below',
    ...                         event_conditions=["Negative", "Neutral", "Neutral", "Negative"])
    >>> fig1 = nk.events_plot(events, data)
    >>> fig1 #doctest: +SKIP
    >>>
    >>> # Create epochs
    >>> epochs = nk.epochs_create(data, events, sampling_rate=100, epochs_end=3)
    >>> fig2 = nk.epochs_plot(epochs)
    >>> fig2 #doctest: +SKIP
    >>>
    >>> # Baseline correction
    >>> epochs = nk.epochs_create(data, events, sampling_rate=100, epochs_end=3,
    ...                           baseline_correction=True)
    >>> fig3 = nk.epochs_plot(epochs)
    >>> fig3 #doctest: +SKIP
    >>>
    >>> # Chunk into n blocks of 1 second
    >>> epochs = nk.epochs_create(data, sampling_rate=100, epochs_end=1)

    """
    # Sanitize data input
    if isinstance(data, tuple):  # If a tuple of data and info is passed
        data = data[0]

    if isinstance(data, (list, np.ndarray, pd.Series)):
        data = pd.DataFrame({"Signal": list(data)})

    # Sanitize events input
    if events is None:
        # No events: tile the signal with back-to-back blocks of the longest epoch.
        # NOTE(review): the stop of arange is exclusive, so a final block ending
        # exactly on the last sample is not emitted -- confirm this is intended.
        max_duration = (np.max(epochs_end) - np.min(epochs_start)) * sampling_rate
        events = np.arange(0, len(data) - max_duration, max_duration)
    if isinstance(events, (int, np.integer)):
        # N evenly spaced onsets strictly inside the signal (both ends dropped).
        events = np.linspace(0, len(data), events + 2)[1:-1]
    if not isinstance(events, dict):
        events = _events_find_label({"onset": events}, event_labels=event_labels, event_conditions=event_conditions)

    event_onsets = list(events["onset"])
    event_labels = list(events["label"])
    if "condition" in events.keys():
        event_conditions = list(events["condition"])

    # Create epochs
    parameters = listify(
        onset=event_onsets, label=event_labels, condition=event_conditions, start=epochs_start, end=epochs_end
    )

    # Find the maximum number of samples in an epoch
    parameters["duration"] = np.array(parameters["end"]) - np.array(parameters["start"])
    epoch_max_duration = int(max(duration * sampling_rate for duration in parameters["duration"]))

    # Pad the data on both sides with NaN rows (as many as the longest epoch) so
    # that epochs overlapping the edges of the recording can still be sliced.
    # The padding is created as float NaN (not object dtype) so that numeric
    # columns stay numeric after concatenation; `DataFrame.append` is not used
    # because it was removed in pandas 2.0.
    length_buffer = epoch_max_duration
    buffer = pd.DataFrame(np.nan, index=range(length_buffer), columns=data.columns)
    data = pd.concat([buffer, data, buffer], ignore_index=True, sort=False)

    # Adjust the onset of the events for the leading buffer
    parameters["onset"] = [onset + length_buffer for onset in parameters["onset"]]

    epochs = {}
    for i, label in enumerate(parameters["label"]):

        # Find indices (in samples, within the padded data)
        start = parameters["onset"][i] + (parameters["start"][i] * sampling_rate)
        end = parameters["onset"][i] + (parameters["end"][i] * sampling_rate)

        # Slice dataframe
        epoch = data.iloc[int(start) : int(end)].copy()

        # Position of each sample in the original (unpadded) data
        sample_index = epoch.index.values - length_buffer

        # Re-index the epoch by time relative to the event onset (in seconds)
        epoch.index = np.linspace(
            start=parameters["start"][i], stop=parameters["end"][i], num=len(epoch), endpoint=True
        )

        if baseline_correction:
            # Use this epoch's own start so that list-valued `epochs_start` works.
            baseline_end = max(0, parameters["start"][i])
            # Only numeric signal columns are corrected; this runs before the
            # 'Index' column is added so that sample positions are not shifted.
            signals = epoch.select_dtypes(include="number")
            epoch[signals.columns] = signals - signals.loc[:baseline_end].mean()

        # Add additional
        epoch["Index"] = sample_index
        epoch["Label"] = parameters["label"][i]
        if parameters["condition"][i] is not None:
            epoch["Condition"] = parameters["condition"][i]

        # Store
        epochs[label] = epoch

    return epochs