predict
Description
predictedY = predict(Mdl,X) returns the predicted responses for the predictor data in the matrix or table X using the trained multiresponse regression model Mdl. For more information, see Prediction with Regression Chain Ensembles.
By default, predict returns the predicted responses as a matrix. That is, this syntax is equivalent to predict(Mdl,X,OutputType="matrix").
predictedY = predict(Mdl,X,OutputType="table") returns the predicted responses as a table.
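The following sketch compares the two output types. It trains a small model on the carbig data used in the example, omitting the Origin variable for brevity, and relies only on the syntaxes documented on this page. The variable names Xnew, Ymat, Ytbl, and sameValues are illustrative. The code also assumes that the matrix columns follow the same response order as the table variables.

```matlab
% Train a small multiresponse model on the carbig data (numeric predictors only)
load carbig
cars = table(Displacement,Horsepower,Model_Year,Weight,Acceleration,MPG);
Mdl = fitrchains(cars,["Acceleration","MPG"]);

% Predict for the first five observations in both output formats
Xnew = cars(1:5,:);
Ymat = predict(Mdl,Xnew);                     % default: numeric matrix
Ytbl = predict(Mdl,Xnew,OutputType="table");  % same predictions, returned as a table

% Extract the table contents as a matrix and compare with the default output
sameValues = isequal(Ymat,Ytbl{:,:})
```

Both calls use the same trained model, so the predictions themselves should not differ. Only the container type changes.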
Examples
Train Multiresponse Regression Model with Regression Chains
Create a regression model with more than one response variable by using fitrchains.
Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s. Create a table containing the predictor variables Displacement, Horsepower, and so on, as well as the response variables Acceleration and MPG. Display the first eight rows of the table.
load carbig
cars = table(Displacement,Horsepower,Model_Year, ...
    Origin,Weight,Acceleration,MPG);
head(cars)
    Displacement    Horsepower    Model_Year    Origin     Weight    Acceleration    MPG
    ____________    __________    __________    _______    ______    ____________    ___

        307            130            70        USA         3504          12         18
        350            165            70        USA         3693        11.5         15
        318            150            70        USA         3436          11         18
        304            150            70        USA         3433          12         16
        302            140            70        USA         3449        10.5         17
        429            198            70        USA         4341          10         15
        454            220            70        USA         4354           9         14
        440            215            70        USA         4312         8.5         14
Categorize the cars based on whether they were made in the USA.
cars.Origin = categorical(cellstr(cars.Origin));
cars.Origin = mergecats(cars.Origin,["France","Japan",...
    "Germany","Sweden","Italy","England"],"NotUSA");
Partition the data into training and test sets. Use approximately 85% of the observations to train a multiresponse model, and 15% of the observations to test the performance of the trained model on new data. Use cvpartition to partition the data.
rng("default") % For reproducibility
c = cvpartition(height(cars),"Holdout",0.15);
carsTrain = cars(training(c),:);
carsTest = cars(test(c),:);
Train a multiresponse regression model by passing the carsTrain training data to the fitrchains function. By default, the function uses bagged ensembles of trees in the regression chains.
Mdl = fitrchains(carsTrain,["Acceleration","MPG"])
Mdl = 
  RegressionChainEnsemble

           PredictorNames: {'Displacement'  'Horsepower'  'Model_Year'  'Origin'  'Weight'}
             ResponseName: ["Acceleration"    "MPG"]
    CategoricalPredictors: 4
                NumChains: 2
            LearnedChains: {2x2 cell}
          NumObservations: 338
Mdl is a trained RegressionChainEnsemble model object. You can use dot notation to access the properties of Mdl. For example, you can specify Mdl.Learners to see the bagged ensembles used to train the model.
Evaluate the performance of the regression model on the test set by computing the test mean squared error (MSE). Smaller MSE values indicate better performance. Return the loss for each response variable separately by setting the OutputType name-value argument to "per-response".
testMSE = loss(Mdl,carsTest,["Acceleration","MPG"], ...
    OutputType="per-response")
testMSE = 1×2
2.4921 9.0568
Predict the response values for the observations in the test set. Return the predicted response values as a table.
predictedY = predict(Mdl,carsTest,OutputType="table")
predictedY=60×2 table
Acceleration MPG
____________ ______
12.573 16.109
10.78 13.988
11.282 12.963
15.185 21.066
12.203 13.773
13.216 14.216
17.117 30.199
16.478 29.033
13.439 14.208
11.552 13.066
13.398 13.271
14.848 20.927
16.552 24.603
12.501 15.359
15.778 19.328
12.343 13.185
⋮
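As a cross-check on the loss output, the sketch below recomputes a per-response MSE directly from the predict output. It rebuilds the split and model from this example so that it runs on its own. The variable names Ypred, Ytrue, and manualMSE are illustrative. The result should be close to testMSE only under some assumptions that have not been verified here: loss computes an unweighted mean squared error, and it ignores observations with missing response values.

```matlab
% Rebuild the data, partition, and model exactly as in the example
load carbig
cars = table(Displacement,Horsepower,Model_Year,Origin,Weight,Acceleration,MPG);
cars.Origin = categorical(cellstr(cars.Origin));
cars.Origin = mergecats(cars.Origin,["France","Japan", ...
    "Germany","Sweden","Italy","England"],"NotUSA");
rng("default") % For reproducibility
c = cvpartition(height(cars),"Holdout",0.15);
carsTrain = cars(training(c),:);
carsTest = cars(test(c),:);
Mdl = fitrchains(carsTrain,["Acceleration","MPG"]);

% Predicted and observed responses as numeric matrices (columns: Acceleration, MPG)
predictedY = predict(Mdl,carsTest,OutputType="table");
Ypred = predictedY{:,:};
Ytrue = carsTest{:,["Acceleration","MPG"]};

% Per-response mean squared error, skipping entries with missing values
manualMSE = mean((Ytrue - Ypred).^2,1,"omitnan")
```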
Input Arguments
Mdl
— Multiresponse regression model
RegressionChainEnsemble object | CompactRegressionChainEnsemble object
Multiresponse regression model, specified as a RegressionChainEnsemble or CompactRegressionChainEnsemble object.
X
— Predictor data
numeric matrix | table
Predictor data, specified as a numeric matrix or a table. Each row of X corresponds to one observation, and each column corresponds to one predictor. X must have the same data type as the predictor data used to train Mdl, and must contain the same predictors.
Data Types: single | double | table
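When Mdl is trained on a numeric matrix, X must also be a numeric matrix with the same predictor columns in the same order. The following is a minimal sketch. It assumes the fitrchains(X,Y) syntax for numeric predictor and response matrices, and it picks three carbig predictors purely for illustration.

```matlab
% Train on numeric matrices instead of a table
load carbig
X = [Displacement Horsepower Weight];   % n-by-3 predictor matrix
Y = [Acceleration MPG];                 % n-by-2 response matrix
Mdl = fitrchains(X,Y);

% New data: a numeric matrix with the same three predictor columns, in the same order
Xnew = X(1:5,:);
predictedY = predict(Mdl,Xnew)          % one row per observation, one column per response
```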
Output Arguments
predictedY
— Predicted responses
numeric matrix | numeric table
Predicted responses, returned as a numeric matrix or table. For observation i in X and response variable j, the value predictedY(i,j) is the mean predicted response value across the regression chains.
For more information, see Prediction with Regression Chain Ensembles.
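Written as a formula, the averaging is shown below. The notation is introduced here for illustration only: C is the number of regression chains (Mdl.NumChains), and \hat{y}^{(c)}_{ij} is the prediction of chain c for observation i and response variable j.

```latex
\mathrm{predictedY}(i,j) = \frac{1}{C} \sum_{c=1}^{C} \hat{y}^{(c)}_{ij}
```

For example, with C = 2 and hypothetical chain predictions of 12.4 and 12.8 for the Acceleration response of one observation, predictedY(i,j) = (12.4 + 12.8)/2 = 12.6.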
Algorithms
Prediction with Regression Chain Ensembles
A regression chain is a sequence of regression models in which the response variables for previous models become predictor variables for subsequent models. If the training data consists of p predictor variables and k response variables, then a regression chain includes exactly k models, each with a different response variable. The first model has p predictors, the second model has p+1 predictors, and so on, with the last model having p+k–1 predictors.
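The growth in the number of predictors along a chain can be checked with a small sketch. This sketch is not the fitrchains implementation. It substitutes linear models (fitlm) for the bagged tree ensembles that fitrchains uses by default, and it uses synthetic data with p = 3 and k = 2. The chain order [2 1] is hypothetical. The sketch also assumes that each model is trained with the observed values of the earlier responses as extra predictors.

```matlab
% Synthetic data: p = 3 predictors, k = 2 responses (illustrative values only)
rng(0)
n = 100; p = 3; k = 2;
X = randn(n,p);
y1 = X*[1; 2; -1] + 0.1*randn(n,1);
y2 = X*[0.5; -1; 2] + 0.1*randn(n,1);
Y = [y1 y2];

% One regression chain with the hypothetical order [2 1]:
% model j uses the original predictors plus the responses of models 1,...,j-1
order = [2 1];
chain = cell(1,k);
Xaug = X;
for j = 1:k
    chain{j} = fitlm(Xaug,Y(:,order(j)));   % linear model, for illustration only
    Xaug = [Xaug Y(:,order(j))];            % append this response as a new predictor
end

% Number of predictors in each model: p, p+1, ..., p+k-1
numPredictors = cellfun(@(m) m.NumPredictors,chain)   % returns [3 4]
```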
Mdl is a regression chain ensemble, where each row of Mdl.Learners corresponds to one regression chain. Each entry in Mdl.Learners is a compact regression model object, and each model produces predictions for one response variable. For example, for regression chain i:
- The first model (Mdl.Learners{i,1}) uses the predict object function of the compact object to predict values for response variable Mdl.ChainOrders(i,1) by using the predictor data in X.
- The second model (Mdl.Learners{i,2}) uses the predict object function to predict values for response variable Mdl.ChainOrders(i,2) by using the predictor data in X and the predicted response values returned by Mdl.Learners{i,1}.
- The process repeats for each model in the regression chain, so that each response variable has a set of predicted response values.
After each regression chain produces predicted responses, the software averages the results across chains to return predictedY.
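The prediction steps above can be sketched from scratch as follows. This is an illustration of the technique, not a reproduction of RegressionChainEnsemble internals. It uses linear models (fitlm) in place of the compact models stored in Mdl.Learners. The variables chainOrders and learners are hypothetical stand-ins for Mdl.ChainOrders and Mdl.Learners. As in the description above, each model's predictions become an extra predictor for the next model in the chain, and the results are then averaged across chains. Training each model on the observed values of the earlier responses is an assumption made for simplicity.

```matlab
% Synthetic data (illustrative only): p = 3 predictors, k = 2 responses
rng(0)
n = 100; p = 3; k = 2;
X = randn(n,p);
y1 = X*[1; 2; -1] + 0.1*randn(n,1);
y2 = X*[0.5; -1; 2] + 0.1*randn(n,1);
Y = [y1 y2];

% Hypothetical chain orders: one row per chain, one column per response
chainOrders = [1 2; 2 1];
numChains = size(chainOrders,1);

% Train each chain, appending observed responses as extra predictors (assumption)
learners = cell(numChains,k);
for i = 1:numChains
    Xaug = X;
    for j = 1:k
        r = chainOrders(i,j);
        learners{i,j} = fitlm(Xaug,Y(:,r));
        Xaug = [Xaug Y(:,r)];
    end
end

% Predict: each model feeds its predictions to the next model in its chain
Xnew = randn(5,p);
predSum = zeros(size(Xnew,1),k);
for i = 1:numChains
    Xaug = Xnew;
    for j = 1:k
        r = chainOrders(i,j);
        yhat = predict(learners{i,j},Xaug);
        predSum(:,r) = predSum(:,r) + yhat;   % accumulate predictions per response
        Xaug = [Xaug yhat];                   % predicted response becomes a predictor
    end
end

% Average across chains: one prediction per observation and response variable
predictedY = predSum/numChains
```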
Version History
Introduced in R2024b
See Also
fitrchains | loss | RegressionChainEnsemble | CompactRegressionChainEnsemble