
Interpretable Time Series Forecasting Using a Temporal Fusion Transformer

Since R2025a

This example shows how to forecast electricity usage using a temporal fusion transformer (TFT) [1]. A TFT is an attention-based network for time series forecasting. The network uses attention mechanisms and variable importance weighting, which provide interpretable insight into which time steps and features most influence its forecasts.

A TFT takes as input past values of a time series, along with other static and time-varying inputs, and outputs a prediction of future values for a specified number of time steps. The inputs to the TFT can be:

  • Known time-varying inputs — Inputs where the future values are known ahead of prediction time.

  • Unknown time-varying inputs — Inputs where the future values are not known ahead of prediction time. The response variable (the value you want to predict) is always an unknown input.

  • Static inputs — Inputs that do not change with time.

The inputs to the network can be numeric or categorical.

This example uses a TFT to forecast electricity usage of multiple clients at hourly intervals over one day, using the previous seven days' usage as input. The model also uses the following inputs:

  1. Hour of day (known time-varying, numeric)

  2. Day of week (known time-varying, numeric)

  3. Hours since start of measurement (known time-varying, numeric)

  4. Client ID (static, categorical)

For reproducibility, set the random seed.

rng(0)

Load and Visualize Data

This example uses a preprocessed version of the Electricity Load Diagrams data set available in the UCI Machine Learning Repository [2] and licensed under CC BY 4.0. The original data set contains electricity consumption values (in kW) for 370 clients, logged every 15 minutes from 2011 to 2014. The preprocessed subset contains the hourly electricity consumption (in kWh) from January 1, 2014 to September 8, 2014, partitioned into training, validation, and test sets. Missing data is represented by NaN values. The data is approximately 10 MB in size. Download the data from the MathWorks website.

filenameData = matlab.internal.examples.downloadSupportFile('nnet','data/ElectricityLoadDiagrams2014.mat')
load(filenameData)

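Each data set is a timetable. Each row contains the electricity consumption (in kWh) logged for one hour, and each of the 370 variables corresponds to one client. The row times, stored in Time, give the hour at which each value was logged.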

Plot the first 192 hours (eight days) of the sixth client in the training data set. The data shows strong daily seasonality, with a period of 24 hours.

clientIdx = 6;
hrs = 1:192;
plot(tftUsageDataTrain.Time(hrs),tftUsageDataTrain{hrs,clientIdx});
xlabel("Time")
ylabel("Electricity Consumption (kWh)")

Extract the IDs of the electricity clients.

clientIDs = string(tftUsageDataTrain.Properties.VariableNames);

Prepare Data for Forecasting

Use the helper function addTimePredictorVariables to add extra columns to each table for hour of day, day of week, and number of hours since the start of the training time series.

startTime = tftUsageDataTrain.Properties.StartTime;

tftUsageDataTrain = addTimePredictorVariables(tftUsageDataTrain,startTime);
tftUsageDataValidation = addTimePredictorVariables(tftUsageDataValidation,startTime);
tftUsageDataTest = addTimePredictorVariables(tftUsageDataTest,startTime);

Normalize Data

Normalize the training data to have a mean of zero and a standard deviation of one. Normalize the validation and test data using the training data statistics. When computing the mean and standard deviation, use the "omitmissing" flag to ignore NaN values in the data.

trainingMean = mean(tftUsageDataTrain,"omitmissing");
trainingStd = std(tftUsageDataTrain,"omitmissing");

dataTrainNormalized = (tftUsageDataTrain - trainingMean) ./ trainingStd;
dataValidationNormalized = (tftUsageDataValidation - trainingMean) ./ trainingStd;
dataTestNormalized = (tftUsageDataTest - trainingMean) ./ trainingStd;

Chunk Data

Use the helper function chunkData to randomly sample 192-hour (eight-day) chunks from the data. Different chunks can overlap.

This example uses 45,000 samples for training, 5,000 samples for validation, and 5,000 samples for testing. To more closely match the setup used in [1], set numTrainSamples to 450,000, numValSamples to 50,000, and numTestSamples to "all".

sampleLength = 192;

numTrainSamples = 45000;
numValSamples = 5000;
numTestSamples = 5000;

[samplesTrain,idsTrain] = chunkData(dataTrainNormalized,sampleLength,numTrainSamples,clientIDs);
[samplesValidation,idsValidation] = chunkData(dataValidationNormalized,sampleLength,numValSamples,clientIDs);
[samplesTest,idsTest] = chunkData(dataTestNormalized,sampleLength,numTestSamples,clientIDs);
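Chunks can overlap because, for a client with a given number of valid time steps, every start index from 1 to (number of valid steps - chunk length + 1) defines a valid window, and consecutive windows share all but one time step. The following sketch illustrates this on a toy 10-step series with 4-step windows (hypothetical sizes chosen for readability, not the example's 192-hour chunks). It then draws three windows without replacement, similar in spirit to the randsample call in chunkData, but using randperm with a local random stream so that the global seed set by rng(0) is not disturbed.

```matlab
% Toy illustration of overlapping windows (sizes are hypothetical).
series = (1:10)';                    % a client with 10 valid time steps
chunkLength = 4;                     % window length
numWindows = numel(series) - chunkLength + 1;   % number of valid start positions

% Build every window; column k starts at time step k.
windows = zeros(chunkLength,numWindows);
for k = 1:numWindows
    windows(:,k) = series(k:k+chunkLength-1);
end

% Randomly choose 3 windows without replacement, using a local stream so
% the global random number generator is left untouched.
s = RandStream("mt19937ar","Seed",0);
chosen = randperm(s,numWindows,3);
sampledWindows = windows(:,chosen);
```

Here, numWindows is 7, and windows(:,1) and windows(:,2) share three of their four time steps.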

Create Datastore

Use the createMultiInputDatastore helper function to create datastores that combine the multiple inputs and the targets for training the neural network. Use 168 hours (seven days) of electricity load as input to the network and the following 24 hours (one day) of electricity load as the target to forecast.

numPastTimeSteps = 168;
numFutureTimeSteps = 24;

dsTrain = createMultiInputDatastore(samplesTrain,numPastTimeSteps,idsTrain);
dsValidation = createMultiInputDatastore(samplesValidation,numPastTimeSteps,idsValidation);
dsTest = createMultiInputDatastore(samplesTest,numPastTimeSteps,idsTest);
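Each 192-step sample is split differently for the unknown and known inputs. The electricity channel is divided into 168 past steps (network input) and 24 future steps (target), while the known time features keep all 192 steps, because their future values are available at prediction time. A minimal sketch of this split on a deterministic placeholder sample, mirroring the indexing in the createMultiInputDatastore helper function (the variable names toySample, pastElec, futureElec, and knownInputs are hypothetical):

```matlab
sampleLength = 192;
numPastTimeSteps = 168;
% Placeholder sample; columns are electricity, hour, day, hoursFromStart.
toySample = reshape(1:sampleLength*4,sampleLength,4);

pastElec   = toySample(1:numPastTimeSteps,1);       % past electricity (input)
futureElec = toySample(numPastTimeSteps+1:end,1);   % future electricity (target)
knownInputs = toySample(:,2:4);                     % known inputs, past and future
```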

Plot an example of the preprocessed training data.

[elecIn,hour,day,hoursFromStart,id,elecOut] = dsTrain.preview{:};

figure
tiledlayout(4,1)

nexttile
plot(1:numPastTimeSteps,elecIn)
hold on
plot(numPastTimeSteps+1:sampleLength,elecOut,'--')
hold off
legend(["Input" "Target"],Location="northwest")
ylabel("Electricity")

nexttile
plot(1:sampleLength,hour)
ylabel("Hour")

nexttile
plot(1:sampleLength,day)
ylabel("Day")

nexttile
plot(1:sampleLength,hoursFromStart)
ylabel("Hours from start")

Create and Explore Temporal Fusion Transformer Network

Use the helper function createTFTNetwork to create the temporal fusion transformer network. To access this function, open the example as a live script.

Create a network with an architecture following [1]:

  • Specify one unknown, time-varying, continuous input (electricity).

  • Specify three known, time-varying, continuous inputs (hour, day, and hours from start).

  • Specify one static, categorical input (ID), with 370 categories.

  • Use 160 hidden units in each component.

  • Use four heads in the self-attention component.

  • Output predictions for three quantiles.

  • Use a dropout probability of 0.1.

inputNames = ["electricity" "hour" "day" "hoursFromStart" "id"];
unknownTimeVaryingInputIdx = 1;
knownTimeVaryingInputIdx = [2 3 4];
staticInputIdx = 5;
categoricalInputIdx = 5;
numCategories = 370;
numHiddenUnits = 160;
numAttentionHeads = 4;
numQuantiles = 3;
dropoutProbability = 0.1;

net = createTFTNetwork(inputNames,unknownTimeVaryingInputIdx, ...
    knownTimeVaryingInputIdx,staticInputIdx,categoricalInputIdx, ...
    numCategories,numHiddenUnits,numAttentionHeads,numPastTimeSteps, ...
    numFutureTimeSteps,numQuantiles, ...
    DropoutProbability=dropoutProbability);

Visualize Network

To view the network, use the Deep Network Designer app.

deepNetworkDesigner(net)

The network is made up of several different components contained in networkLayer objects. To view the contents of a network layer, double-click the layer in Deep Network Designer.

Gated Linear Units

TFTs use gated linear unit (GLU) activations. To view an example of a GLU, double-click on the lstm_gate layer in Deep Network Designer.

A GLU is a learnable activation function that lets the network control how much of its input to propagate. A sigmoid gate returns values between 0 and 1 that scale a linear transformation of the input element-wise: values near 0 suppress the input, and values near 1 pass it through. To create a GLU network layer, use the gluNetworkLayer helper function included with this example.
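A minimal sketch of a GLU forward pass for a single feature vector, following the formulation in [1], in which a sigmoid gate multiplies a linear transformation of the input element-wise. The weights are random placeholders rather than learned parameters, the names Wg, bg, Wv, and bv are hypothetical, and this is not the implementation in the gluNetworkLayer helper function.

```matlab
% Random placeholder weights; a trained network learns these.
s = RandStream("mt19937ar","Seed",0);
numHidden = 4;
x = randn(s,numHidden,1);                           % input feature vector
Wg = randn(s,numHidden); bg = randn(s,numHidden,1); % gate weights
Wv = randn(s,numHidden); bv = randn(s,numHidden,1); % value weights

sigmoid = @(z) 1 ./ (1 + exp(-z));

gate  = sigmoid(Wg*x + bg);   % each element lies between 0 and 1
value = Wv*x + bv;            % linear transformation of the input
y = gate .* value;            % the gate scales each element of the value
```

Because every gate element is less than 1, no element of y is larger in magnitude than the corresponding element of value.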

Gated Residual Networks

TFTs use gated residual networks (GRNs) for nonlinear processing. A GRN consists of fully connected layers with exponential linear unit (ELU) activations and a skip connection gated by a GLU. To view an example of a GRN, double-click on the static_context_varselect layer in Deep Network Designer.

If the nonlinear processing does not help the predictions of the network, then the GLU can learn to suppress the nonlinear branch, effectively skipping the unit. To create a GRN network layer, use the grnNetworkLayer helper function included with this example.
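A minimal sketch of the GRN computation described in [1]: the input a (plus an optional context vector c) passes through a fully connected layer, an ELU, and a second fully connected layer, and the result is gated by a GLU, added back to a, and layer-normalized. All weights are random placeholders, the layer normalization here has no learnable scale or offset (a simplifying assumption), and the variable names are hypothetical. The last line forces the gate closed to show that the unit then reduces to a normalized skip connection.

```matlab
s = RandStream("mt19937ar","Seed",0);
d = 4;                                   % hidden size
a = randn(s,d,1);                        % primary input
c = randn(s,d,1);                        % optional static context
W1 = randn(s,d); b1 = randn(s,d,1);
W2 = randn(s,d); b2 = randn(s,d,1);
W3 = randn(s,d);                         % context weights
Wg = randn(s,d); bg = randn(s,d,1);      % GLU gate weights
Wv = randn(s,d); bv = randn(s,d,1);      % GLU value weights

elu = @(z) max(z,0) + min(exp(z) - 1,0); % ELU with alpha = 1
sigmoid = @(z) 1 ./ (1 + exp(-z));
layerNorm = @(z) (z - mean(z)) ./ sqrt(var(z,1) + 1e-5); % no scale/offset
glu = @(z,gateBias) sigmoid(Wg*z + gateBias) .* (Wv*z + bv);

% Nonlinear branch followed by a gated skip connection.
eta2 = elu(W2*a + W3*c + b2);
eta1 = W1*eta2 + b1;
out = layerNorm(a + glu(eta1,bg));

% With a very negative gate bias, the gate outputs 0 and the nonlinear
% branch is suppressed, so the GRN reduces to layerNorm(a).
outSkipped = layerNorm(a + glu(eta1,bg - 1000));
```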

Variable Selection Networks

The known, unknown, and static data are inputs into the variable selection networks. A variable selection network aggregates the inputs with a learned weighting that depends on the input values. To view an example of a variable selection network, double-click on the layer future_varselect in Deep Network Designer. This layer performs variable selection on three input variables: the future values of the hour, day, and hoursFromStart inputs.

The variable selection network allows the TFT to amplify important inputs and suppress less important ones. You can interpret the outputs of the softmax layer as importance scores, allowing you to see which input features the network learns are most important for forecasting. To create a variable selection network layer, use the variableSelectionNetworkLayer function included with this example.
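The following sketch shows the core idea of variable selection: each variable is processed separately, a softmax produces input-dependent selection weights, and the output is the weighted sum of the processed variables. It is a simplification of the network in [1], which uses GRNs both for the per-variable processing and for computing the weights; here, plain matrix multiplications stand in for them, and all weights are random placeholders with hypothetical names. The vector weights plays the role of the importance scores.

```matlab
s = RandStream("mt19937ar","Seed",1);
numVariables = 3;                    % for example, hour, day, and hoursFromStart
d = 4;                               % hidden size
xi = randn(s,d,numVariables);        % embedded inputs, one column per variable

% Per-variable processing (a separate GRN per variable in [1]).
Wvar = randn(s,d,d,numVariables);
processed = zeros(d,numVariables);
for j = 1:numVariables
    processed(:,j) = Wvar(:,:,j) * xi(:,j);
end

% Input-dependent selection weights (a GRN followed by a softmax in [1]).
Wsel = randn(s,numVariables,d*numVariables);
logits = Wsel * xi(:);
weights = exp(logits - max(logits));
weights = weights / sum(weights);    % softmax: nonnegative and sums to 1

% Combine the processed variables using the selection weights.
combined = processed * weights;
```

Because the weights are nonnegative and sum to 1, you can compare them directly across variables, which is what makes them usable as importance scores.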

Interpretable Multi-Head Self-Attention

The TFT uses a variant of multi-head attention. Instead of concatenating the outputs of the heads, each head computes its own attention scores, all heads share a single set of values, and the layer averages the attention scores across heads before applying them to the shared values. To view an example of an interpretable multi-head self-attention layer, double-click on the attn layer in Deep Network Designer. This layer has four attention heads.

You can interpret the attention scores from this layer because all heads apply their scores to the same values, so the scores from different heads are directly comparable. To create an interpretable multi-head self-attention network layer, use the interpretableSelfAttentionNetworkLayer helper function included with this example.
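A minimal sketch of interpretable multi-head self-attention as described in [1], for a single sequence with a causal mask: each head has its own query and key projections, all heads share one value projection, and the heads' attention scores are averaged before being applied to the shared values. The sizes, weights, and variable names are hypothetical placeholders, and this is not the implementation in the interpretableSelfAttentionNetworkLayer helper function.

```matlab
s = RandStream("mt19937ar","Seed",2);
numTimeSteps = 5;
dModel = 8;
numHeads = 4;
dAttn = dModel/numHeads;

X  = randn(s,numTimeSteps,dModel);        % one row per time step
WQ = randn(s,dModel,dAttn,numHeads);      % per-head query weights
WK = randn(s,dModel,dAttn,numHeads);      % per-head key weights
WV = randn(s,dModel,dAttn);               % value weights shared by all heads
WH = randn(s,dAttn,dModel);               % output projection

% Causal mask: entry (i,j) is true when key j lies in the future of query i.
isFuture = (1:numTimeSteps)' < (1:numTimeSteps);

meanScores = zeros(numTimeSteps);
for h = 1:numHeads
    logits = (X*WQ(:,:,h)) * (X*WK(:,:,h))' / sqrt(dAttn);
    logits(isFuture) = -Inf;
    A = exp(logits - max(logits,[],2));
    A = A ./ sum(A,2);                    % row-wise softmax
    meanScores = meanScores + A/numHeads; % average the scores over heads
end

V = X*WV;                                 % shared values
out = (meanScores*V)*WH;
```

Because the values are shared, averaging the scores and then applying them to V gives the same result as averaging the per-head outputs, so meanScores summarizes what all heads attend to.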

Specify Training Options

Specify the training options.

  • Train using Adam optimization.

  • Train for 5 epochs.

  • Use a mini-batch size of 64.

  • Use an initial learning rate of 0.001.

  • Use a gradient threshold of 0.01.

  • Shuffle the data every epoch.

  • Use dsValidation as validation data.

  • Validate once per epoch.

  • Display the training progress in a plot.

  • Disable the verbose output.

miniBatchSize = 64;
validationFrequency = ceil(numTrainSamples/miniBatchSize); % validate once per epoch

options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=0.001, ...
    GradientThreshold=0.01, ...
    Shuffle="every-epoch", ...
    ValidationData=dsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Plots="training-progress", ...
    Verbose=false);

Train Temporal Fusion Transformer

Train the temporal fusion transformer network using the trainnet function. Use the custom loss function quantileLoss to train the network to predict the 10th, 50th, and 90th percentile forecasts. By default, the trainnet function uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For more information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

Training the model is computationally expensive. To save time, this example sets doTraining to false and loads a pretrained network, which is approximately 10 MB in size. To train the network yourself, set doTraining to true.

quantiles = [0.1; 0.5; 0.9];

doTraining = false;
if doTraining
    trainedNet = trainnet(dsTrain,net,@(Y,T) quantileLoss(Y,T,quantiles,DataFormat="CBT"),options);
else
    fileNameNetwork = matlab.internal.examples.downloadSupportFile('nnet','data/ElectricityLoadForecastingTemporalFusionTransformer.mat');
    load(fileNameNetwork)
end

Test Network on Unseen Data

Use the minibatchpredict function to forecast electricity usage on the test set with the trained network. Use the Outputs name-value argument to also return the variable selection importance scores from the network. Denormalize the predicted values using the training data statistics.

[testPredictions,testPastImportance,testFutureImportance] = minibatchpredict(trainedNet, ...
    dsTest, ...
    Outputs=["quantile_out","past_varselect/scores","future_varselect/scores"]);

denormalizedTestPredictions = denormalize(testPredictions,idsTest,trainingMean,trainingStd);
denormalizedTestTargets = denormalize(samplesTest(numPastTimeSteps+1:end,1,:),idsTest,trainingMean,trainingStd);

Compute the quantile loss and the q-risk over the test set. The q-risk is similar to the quantile loss, but it is calculated separately for each quantile, and it is computed on the denormalized data and then normalized by the magnitude of the targets, rather than computed on the normalized data.

testLoss = quantileLoss(testPredictions,samplesTest(numPastTimeSteps+1:end,1,:),quantiles)
testLoss = single

    0.2530
testQRisk = qRiskMetric(denormalizedTestPredictions,denormalizedTestTargets,quantiles)
testQRisk = 1×3 single row vector

    0.0312    0.0611    0.0346

Compare the predictions with the targets. The three output channels in the prediction represent the 10th, 50th, and 90th percentile forecasts, respectively. Plot the 50th percentile as the forecast and use the 10th and 90th percentiles as the bounds of an 80% prediction interval.

numTestSamples = size(samplesTest,3);
sampleIdx = randi(numTestSamples);
sampleToPlot = denormalizedTestTargets(:,:,sampleIdx);

figure
plot(sampleToPlot)
hold on
p10 = denormalizedTestPredictions(:,1,sampleIdx);
p50 = denormalizedTestPredictions(:,2,sampleIdx);
p90 = denormalizedTestPredictions(:,3,sampleIdx);
plot(p50, '--')
patch([1:numFutureTimeSteps flip(1:numFutureTimeSteps)]', ...
    [p10;flip(p90)],1, ...
    EdgeColor="none",FaceAlpha=0.1,CDataMapping="direct")

title("24-Hour Forecast")
xlabel("Hours")
ylabel("Electricity Consumption (kWh)")
legend(["Actual" "Forecast" "80% Interval"])
hold off
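Because the 10th and 90th percentile forecasts bound an 80% interval, one simple calibration check is the fraction of targets that fall inside it, which should be close to 0.8 for a well-calibrated model. A minimal sketch on toy arrays in the time-by-quantile-by-observation layout used above; to apply it to the test set, you could substitute denormalizedTestTargets and denormalizedTestPredictions (an illustrative suggestion, not a step of the original example).

```matlab
% Toy data: 4 time steps, 1 observation; channels are p10, p50, p90.
targets = [5; 7; 9; 11];
predictions = cat(2,[4; 6; 10; 10],[5; 7; 11; 11],[6; 8; 12; 12]);

lowerBound = predictions(:,1,:);     % 10th percentile forecast
upperBound = predictions(:,3,:);     % 90th percentile forecast
inInterval = targets >= lowerBound & targets <= upperBound;
coverage = mean(inInterval,"all")    % 3 of 4 targets inside: 0.75
```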

Interpret Outputs of TFT Network

The TFT architecture allows you to analyze some of its individual components and interpret the relationships and patterns the model has learned.

Variable Importance

The output scores of the variable selection networks indicate the relative importance the network places on each input variable. Compute the median importance scores for the past and future inputs over all observations and time steps. The importance score for a variable is the weight that the network applies to the branch that processes that input variable. A higher importance score means the network gives a greater weight to that variable, implying it is more important to the predictions of the network.

observationDim = 3;
timeDim = 1;
medianPastImportance = median(testPastImportance,[observationDim timeDim]);
medianFutureImportance = median(testFutureImportance,[observationDim timeDim]);

Visualize the variable importance scores for past inputs. The scores show that hour of day and past electricity usage are the most important past input variables to the model.

figure
bar(medianPastImportance)
title("Importance Scores for Past Inputs")
ylabel("Importance")
xticklabels(["Electricity" "Hour" "Day" "Hours from start"]);

Visualize the variable importance scores for future inputs using a bar chart. The scores also show that hour of day is the most important future input variable to the model.

figure
bar(medianFutureImportance)
title("Importance Scores for Future Inputs")
ylabel("Importance")
xticklabels(["Hour" "Day" "Hours from start"]);

Attention Scores

You can use the scores output of the interpretable multi-head self-attention layer to analyze the most important past time steps the network uses in its prediction. Compute the attention scores for a subset of 1000 observations of the test data.

subsetSize = 1000;
sampleIdx = randsample(numTestSamples,subsetSize);
dsTestSubset = subset(dsTest,sampleIdx);

testAttnScores = minibatchpredict(trainedNet,dsTestSubset, ...
    Outputs="attn/scores",OutputDataFormats="UUUB");

Compute the mean of the attention scores for the first future time step over all observations and attention heads.

headDim = 3;
observationDim = 4;
firstFutureTimeStep = numPastTimeSteps + 1;
meanAttentionScoreT1 = mean(testAttnScores(:,firstFutureTimeStep,:,:),[headDim observationDim]);

Because of the causal masking in the attention layer, the attention score between any time step and the time steps in its future is zero. Keep only the scores for time steps up to and including the first forecast time step.

meanAttentionScoreT1 = meanAttentionScoreT1(1:numPastTimeSteps+1);

Visualize the attention score across time steps. The attention scores show a peak every 24 hours, reflecting the strong 24-hour seasonal trend in the data.

figure
plot(-numPastTimeSteps:0,meanAttentionScoreT1)
title("Mean Attention Score for First Forecast Time Step")
xlabel("Hours")
ylabel("Attention Score")
xlim([-numPastTimeSteps 0])

References

  1. Lim, Bryan, et al. “Temporal Fusion Transformers for Interpretable Multi-Horizon Time Series Forecasting.” International Journal of Forecasting, vol. 37, no. 4, Oct. 2021, pp. 1748–64, https://doi.org/10.1016/j.ijforecast.2021.03.012.

  2. Trindade, A. “ElectricityLoadDiagrams20112014.” UCI Machine Learning Repository, 2015, https://doi.org/10.24432/C58C86.

  3. Wen, Ruofeng, et al. “A Multi-Horizon Quantile Recurrent Forecaster.” arXiv:1711.11053, 28 June 2018, https://doi.org/10.48550/arXiv.1711.11053.

Supporting Functions

Predictor Creation Function

The function addTimePredictorVariables computes extra time variables and adds them to the table. The time variables are hour of day, day of week, and hours since the start time.

function tbl = addTimePredictorVariables(tbl,startTime)
tbl.Hour = hour(tbl.Time);
tbl.Day = day(tbl.Time, "dayofweek");
tbl.HoursFromStart = hours(tbl.Time - startTime);
end

Data Chunking Function

The function chunkData randomly chooses numSamples samples of length chunkLength from the input data. To chunk the entire input data, set numSamples to "all".

function [samples,ids] = chunkData(data,chunkLength,numSamples,allIDs)
% Compute all valid sample points in the form (row, column)
validSampleLocations = [];
numIDs = numel(allIDs);

for ii = 1:numIDs
    % Ignore leading and trailing NaNs
    startIdx = find(~ismissing(data.(allIDs(ii))),1);
    endIdx = find(~ismissing(data.(allIDs(ii))),1,"last");
    numValidTimeSteps = endIdx - startIdx + 1;

    if numValidTimeSteps >= chunkLength % a series of exactly chunkLength steps gives one sample
        numSamplesPerObservation = numValidTimeSteps - chunkLength + 1;
        locs = [(startIdx:startIdx+numSamplesPerObservation-1)',ii*ones(numSamplesPerObservation,1)];
        validSampleLocations = [validSampleLocations;locs];
    end
end

% Randomly choose sample points without replacement
if strcmp(numSamples,"all")
    numSamples = size(validSampleLocations,1);
    locsToChoose = 1:size(validSampleLocations,1);
else
    locsToChoose = randsample(size(validSampleLocations,1),numSamples);
end

numFeatures = 4; % Electricity, hour, day, hours from start
samples = zeros(chunkLength,numFeatures,numSamples);
ids = categorical(strings(numSamples,1),allIDs);

% Sample
for ii = 1:numSamples
    startRow = validSampleLocations(locsToChoose(ii),1);
    id = allIDs(validSampleLocations(locsToChoose(ii),2));

    samples(:,1,ii) = data.(id)(startRow:startRow+chunkLength-1);
    samples(:,2,ii) = data.Hour(startRow:startRow+chunkLength-1);
    samples(:,3,ii) = data.Day(startRow:startRow+chunkLength-1);
    samples(:,4,ii) = data.HoursFromStart(startRow:startRow+chunkLength-1);

    ids(ii) = id;
end
end

Datastore Creation Function

The function createMultiInputDatastore prepares the inputs and outputs for neural network training by creating a datastore.

function ds = createMultiInputDatastore(samples,numPastTimeSteps,ids)
adsElecInputs = arrayDatastore(samples(1:numPastTimeSteps,1,:),IterationDimension=3);
adsElecTargets = arrayDatastore(samples(numPastTimeSteps+1:end,1,:),IterationDimension=3);
adsHour = arrayDatastore(samples(:,2,:),IterationDimension=3);
adsDay = arrayDatastore(samples(:,3,:),IterationDimension=3);
adsHoursFromStart = arrayDatastore(samples(:,4,:),IterationDimension=3);
adsID = arrayDatastore(ids);

ds = combine(adsElecInputs,adsHour,adsDay,adsHoursFromStart,adsID,adsElecTargets);
end

Quantile Loss Function

The quantileLoss function computes the quantile loss [3] for the specified quantiles, summed over all quantile outputs and averaged over observations and forecast time steps:

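$$L = \frac{1}{NT} \sum_{\text{observations}} \sum_{q \in Q} \sum_{\text{time steps}} \mathrm{QL}(y, y_{\text{pred}}, q),$$

where the quantile loss is

$$\mathrm{QL}(y, y_{\text{pred}}, q) = q \max(y - y_{\text{pred}}, 0) + (1 - q) \max(y_{\text{pred}} - y, 0).$$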

Here, $y$ is the target time series to forecast, $y_{\text{pred}}$ is the model's prediction, $N$ is the total number of observations, $T$ is the total number of forecast time steps, $q$ is the quantile being forecast, and $Q$ is the set of all quantiles (in this example, $Q = \{0.1, 0.5, 0.9\}$). If $q < 0.5$, then the quantile loss penalizes overprediction more than underprediction, which encourages the model to underpredict the true value. If $q > 0.5$, then the quantile loss encourages the model to overpredict. When $q = 0.5$, the quantile loss equals half the absolute error, so minimizing it is equivalent to minimizing the L1 loss (mean absolute error).
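A small worked example of this asymmetry: for a target of 10, compare a forecast that is 2 too low with one that is 2 too high, for each quantile used in this example. The anonymous function QL is a hypothetical scalar version of the quantile loss formula, not the quantileLoss helper function.

```matlab
% Quantile (pinball) loss for one target, one prediction, and quantile q.
QL = @(y,yPred,q) q.*max(y - yPred,0) + (1 - q).*max(yPred - y,0);

q = [0.1 0.5 0.9];
lossLow  = QL(10,8,q)    % forecast too low:  [0.2 1.0 1.8]
lossHigh = QL(10,12,q)   % forecast too high: [1.8 1.0 0.2]
```

For q = 0.1, overpredicting by 2 costs nine times as much as underpredicting by 2, which pushes the forecast down toward the 10th percentile. For q = 0.5, both errors cost 1.0, which is half the absolute error of 2.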

function l = quantileLoss(Y,T,quantiles,options)
arguments
    Y
    T
    quantiles
    options.DataFormat = "TCB"
end

predictionUnderflow = T - Y;
channelDim = strfind(options.DataFormat,"C");
quantiles = shiftdim(quantiles,1-channelDim);

qLoss = quantiles .* max(predictionUnderflow,0) + (1 - quantiles) .* max(-predictionUnderflow,0);

observationDim = strfind(options.DataFormat,"B");
timeDim = strfind(options.DataFormat,"T");
numObservations = size(Y,observationDim);
numTimeSteps = size(Y,timeDim);
l = sum(qLoss,"all") / (numObservations*numTimeSteps);
end

q-risk Metric Function

The qRiskMetric function computes the q-risk for the specified quantiles:

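$$\text{q-risk} = \frac{2 \sum_{\text{observations}} \sum_{\text{time steps}} \mathrm{QL}(y, y_{\text{pred}}, q)}{\sum_{\text{observations}} \sum_{\text{time steps}} |y|}.$$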

You can use the q-risk to compare the results to Ref. [1] and the other papers it references. The results in this example do not match the results in Ref. [1] because this example trains on a smaller subset of the ElectricityLoadDiagrams20112014 data set.
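Dividing by the summed magnitude of the targets makes the q-risk scale-free, which is what allows comparisons across data sets and papers. A small sketch with toy values for a single quantile, using a hypothetical anonymous function qRisk rather than the qRiskMetric helper function: scaling both the targets and the forecasts by 1000 leaves the q-risk unchanged.

```matlab
% q-risk for a single quantile q, summed over all elements.
qRisk = @(y,yPred,q) 2*sum(q.*max(y - yPred,0) + (1 - q).*max(yPred - y,0),"all") ...
    / sum(abs(y),"all");

y = [10 20 30 40];                 % toy targets
yPred = [12 18 30 44];             % toy median forecasts

r = qRisk(y,yPred,0.5)                     % 2*4/100 = 0.08
rScaled = qRisk(1000*y,1000*yPred,0.5)     % also 0.08
```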

function qRisk = qRiskMetric(Y,T,quantiles)
predictionUnderflow = T - Y;
weightedErrors = quantiles' .* max(predictionUnderflow,0) + (1 - quantiles') .* max(-predictionUnderflow,0);

meanQuantileLoss = mean(weightedErrors,[1 3]); % one value per quantile
normalizer = mean(abs(T),[1 3]);

qRisk = 2*meanQuantileLoss ./ normalizer;
end

Denormalization Function

The denormalize function denormalizes the predictions made by the TFT.

function denormalizedPredictions = denormalize(predictions,ids,trainingMean,trainingStd)
mu = shiftdim(trainingMean{:,double(ids)},-1);
sigma = shiftdim(trainingStd{:,double(ids)},-1);

denormalizedPredictions = predictions .* sigma + mu;
end
