fit

Train drift-aware learner for incremental learning with new data

Since R2022b

    Description

    Mdl = fit(Mdl,X,Y) returns an incremental drift-aware learning model Mdl, which represents the input incremental drift-aware learning model Mdl trained using the predictor and response data, X and Y, respectively.

    fit does not update Mdl.Metrics.

    Mdl = fit(Mdl,X,Y,Name=Value) uses additional options specified by one or more name-value arguments. For example, you can specify that the columns of the predictor data matrix correspond to observations, and set observation weights.

    Examples

    Load the human activity dataset. Randomly shuffle the data.

    load humanactivity;
    n = numel(actid);
    rng(1) % For reproducibility
    idx = randsample(n,n);

    For details on the data set, enter Description at the command line.

    Define the predictor and response variables.

    X = feat(idx,:);
    Y = actid(idx);

    Responses can be one of five classes: Sitting, Standing, Walking, Running, or Dancing.

    Dichotomize the response by identifying whether the subject is moving (actid > 2).

    Y = Y > 2;

    Flip labels for the second half of the dataset to simulate drift.

    Y(floor(numel(Y)/2):end,:) = ~Y(floor(numel(Y)/2):end,:);

    Initiate a default incremental drift-aware model for classification as follows:

    1. Create an incremental linear SVM model for binary classification. Specify an estimation period of 5000 observations and the SGD solver.

    2. Initiate a default incremental drift-aware model using the incremental linear SVM model as the base learner. Specify a training period of 5000 observations.

    baseMdl = incrementalClassificationLinear(EstimationPeriod=5000,Solver="sgd");
    idaMdl = incrementalDriftAwareLearner(baseMdl,TrainingPeriod=5000);

    idaMdl is an incrementalDriftAwareLearner model. All its properties are read-only. By default, incrementalDriftAwareLearner uses the Hoeffding's Bound drift detection method based on moving averages ("hddma").

    idaMdl must be fit to data before you can use it to perform any other operations.

    Fit the incremental drift-aware model to the training data, in chunks of 50 observations at a time, by using the fit function. At each iteration:

    1. Simulate a data stream by processing 50 observations.

    2. Overwrite the previous incremental model with a new one fitted to the incoming observations.

    3. Update the performance metrics by using updateMetrics, and then store the first linear coefficient β1, the number of training observations, the drift status, and the classification error to see how they evolve during incremental training.

    % Preallocation
    numObsPerChunk = 50;
    nchunk = floor(n/numObsPerChunk);
    beta1 = zeros(nchunk,1);    
    numtrainobs = zeros(nchunk,1);
    dstatus = zeros(nchunk,1);
    statusname = strings(nchunk,1);
    driftTimes = [];
    ce = array2table(zeros(nchunk,2),VariableNames=["Cumulative" "Window"]);
    
    % Incremental fitting
    for j = 1:nchunk
        ibegin = min(n,numObsPerChunk*(j-1) + 1);
        iend   = min(n,numObsPerChunk*j);
        idx = ibegin:iend;    
    
        idaMdl = fit(idaMdl,X(idx,:),Y(idx));
        idaMdl = updateMetrics(idaMdl,X(idx,:),Y(idx));
        beta1(j) = idaMdl.BaseLearner.Beta(1);
        
        % Record drift status and classification error
        statusname(j) = string(idaMdl.DriftStatus); 
        ce{j,:} = idaMdl.Metrics{"ClassificationError",:};
        numtrainobs(j) = idaMdl.NumTrainingObservations; 
    
        if idaMdl.DriftDetected
           dstatus(j) = 2;  
           driftTimes(end+1) = j; 
        elseif idaMdl.WarningDetected
           dstatus(j) = 1;
        else 
           dstatus(j) = 0;
        end   
     
    end

    idaMdl is an incrementalDriftAwareLearner model object trained on all the data in the stream.
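
    The chunk indexing used in the loop can be checked on its own. The following is a minimal sketch in base MATLAB; the stream length nDemo and chunk size chunkDemo are hypothetical values chosen for illustration and are not taken from the human activity data. Because the number of chunks is computed with floor, any observations left over after the last full chunk are never passed to fit.

```matlab
% Hypothetical stream length and chunk size (illustrative values only)
nDemo = 23;
chunkDemo = 5;
nchunkDemo = floor(nDemo/chunkDemo);   % Number of full chunks

% Start and end index of each chunk, computed as in the fitting loop
ibeginDemo = min(nDemo,chunkDemo*((1:nchunkDemo)' - 1) + 1);
iendDemo   = min(nDemo,chunkDemo*(1:nchunkDemo)');
chunkBounds = [ibeginDemo iendDemo];

% Observations after the last full chunk are not processed
numUnused = nDemo - iendDemo(end);
```

    With these hypothetical values, chunkBounds has the rows [1 5], [6 10], [11 15], and [16 20], and numUnused is 3.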

    To see how the parameters evolve during incremental learning, plot them on separate tiles.

    tiledlayout(2,1)
    set(groot,DefaultConstantLineLineWidth=1.5);
    nexttile
    plot(beta1)
    ylabel("\beta_1")
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk,"r-.","EstimationPeriod")
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk + driftTimes,"r-.")
    xlabel('Iteration')
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.TrainingPeriod)/numObsPerChunk, ...
        "b-.",{"Estimation +","Training Period"},LabelVerticalAlignment="middle")
    xline(floor(numel(Y)/2)/numObsPerChunk,"m--","Drift", ...
        LabelVerticalAlignment="middle")
    
    nexttile
    plot(numtrainobs)
    ylabel("Number of Training Observations")
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk,"r-.","EstimationPeriod")
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk + driftTimes,"r-.")
    xlabel("Iteration")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.TrainingPeriod)/numObsPerChunk, ...
        "b-.",{"Estimation +","Training Period"},LabelVerticalAlignment="middle")
    xline(floor(numel(Y)/2)/numObsPerChunk,"m--","Drift", ...
        LabelVerticalAlignment="middle")

    [Figure: two tiled plots against Iteration. Top: \beta_1. Bottom: Number of Training Observations. Each tile includes vertical reference lines.]

    The plots suggest that fit does not fit the model to the data or update the parameters until after the estimation period. After the function detects a drift, it waits for another idaMdl.BaseLearner.EstimationPeriod observations before fitting the new model to the data.

    Plot the cumulative and per-window classification error. Mark the warm-up and training periods, and the iterations where drift was detected.

    figure()
    h = plot(ce.Variables);
    
    xlim([0 nchunk])
    ylabel("Classification Error")
    xlabel("Iteration")
    
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.MetricsWarmupPeriod)/numObsPerChunk, ...
        "g-.","Estimation + Warmup Period")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.MetricsWarmupPeriod)/numObsPerChunk+ ...
        driftTimes,"g-.","Estimation + Warmup Period")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.TrainingPeriod)/numObsPerChunk, ...
        "b-.","Estimation + Training Period",LabelVerticalAlignment="middle")
    xline(driftTimes,"m--","Drift",LabelVerticalAlignment="middle")
    
    legend(h,ce.Properties.VariableNames)
    legend(h,Location="north")

    [Figure: Classification Error against Iteration, with Cumulative and Window curves and vertical reference lines.]

    Plot the drift status versus the iteration number.

    gscatter(1:nchunk,dstatus,statusname,"gbr","o",5,"on","Iteration","Drift Status","filled")

    [Figure: Drift Status against Iteration. Legend entries: Stable, Drift.]

    Predict labels for the second half of the data and check the accuracy of the model updated after the drift.

    n = floor(numel(Y)/2);
    yhat = predict(idaMdl,X(n:end,:));
    accuracy = sum(Y(n:end)==yhat)/n
    accuracy = 
    0.9960
    

    Load the robotarm data set. Obtain the sample size n and the number of predictor variables p.

    load robotarm
    n = numel(ytrain);
    p = size(Xtrain,2);

    For details on the data set, enter Description at the command line.

    Introduce an artificial drift to the response variable between observations 2500 and 5000.

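    Y = ytrain;
    j = 1.25;                        % Initial scale factor
    for i = 2500:1250:5000
        idx = min(i+1250,5000);      % Last observation of the current block
        Y(i:idx) = ytrain(i:idx)*j;  % Scale the original responses in the block
        j = j + 0.25;                % Increase the scale factor for the next block
    end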
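
    To see which scale factor each observation receives, the same loop can be applied to a vector of ones. The following is a minimal sketch in base MATLAB; the vector length nDemo is a hypothetical value (any length of at least 5000 behaves the same way), and the names base, scale, and f are introduced here for illustration only.

```matlab
nDemo = 6000;             % Hypothetical number of observations
base = ones(nDemo,1);     % Stand-in for ytrain
scale = base;
f = 1.25;                 % Initial scale factor
for i = 2500:1250:5000
    idx = min(i+1250,5000);
    scale(i:idx) = base(i:idx)*f;
    f = f + 0.25;
end

% Scale factor at selected observations
selected = scale([2499 2500 3749 3750 5000 5001])';
```

    In this sketch, selected is [1 1.25 1.25 1.5 1.75 1]. The blocks overlap at their endpoints, so observation 3750 takes the factor 1.5 and observation 5000 takes the factor 1.75, while observations outside 2500 through 5000 are unchanged.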

    Initiate an incremental drift-aware model for regression as follows:

    1. Create an incremental linear SVM model for regression. Specify an estimation period of 500 observations and the SGD solver.

    2. Create an incremental drift detector for continuous data.

    3. Initiate an incremental drift-aware model using the incremental linear SVM model as the base learner and the drift detector you created. Specify a training period of 2000.

    baseMdl = incrementalRegressionLinear(EstimationPeriod=500,Solver="sgd",MetricsWarmUpPeriod=750);
    ddetector = incrementalConceptDriftDetector("hddma",InputType="continuous",Alternative="greater");
    idaMdl = incrementalDriftAwareLearner(baseMdl,DriftDetector=ddetector,TrainingPeriod=2000);

    idaMdl is an incrementalDriftAwareLearner model. All its properties are read-only.

    Specify the number of observations in each chunk, and compute the number of iterations for creating a stream of data.

    numObsPerChunk = 10;
    nchunk = floor(n/numObsPerChunk);

    Preallocate the variables for tracking the drift status and drift time, and storing the regression error and number of training observations.

    dstatus = zeros(nchunk,1);
    statusname = strings(nchunk,1);
    driftTimes = [];
    
    ei = array2table(nan(nchunk,2),VariableNames=["Cumulative","Window"]);
    numtrainobs = zeros(nchunk,1);

    Perform incremental learning on the data stream by using the updateMetrics and fit functions. At each iteration:

    1. Simulate a data stream by processing 10 observations at a time.

    2. Call updateMetrics to update the cumulative and window regression error (epsilon-insensitive loss) of the model given the incoming chunk of observations. Overwrite the previous incremental model to update the losses in the Metrics property. Note that the function does not fit the model to the chunk of new data. Specify the observation orientation.

    3. Call fit to fit the model to the incoming chunk of observations. Overwrite the previous incremental model to update the model parameters. Specify the observation orientation.

    4. Store the regression error and number of training observations.

    rng(123) % For reproducibility
    for j = 1:nchunk
    
        ibegin = min(n,numObsPerChunk*(j-1) + 1);
        iend   = min(n,numObsPerChunk*j);
        idx = ibegin:iend;
    
        idaMdl = updateMetrics(idaMdl,Xtrain(idx,:),Y(idx),ObservationsIn="rows");
        ei{j,:} = idaMdl.Metrics{"EpsilonInsensitiveLoss",:};
    
        idaMdl = fit(idaMdl,Xtrain(idx,:),Y(idx),ObservationsIn="rows");
        numtrainobs(j) = idaMdl.NumTrainingObservations;
    
        statusname(j) = string(idaMdl.DriftStatus);
        if idaMdl.DriftDetected
           dstatus(j) = 2;
           driftTimes(end+1) = j;
        elseif idaMdl.WarningDetected
           dstatus(j) = 1;
        else 
           dstatus(j) = 0;
        end   
       
    end

    idaMdl is an incrementalDriftAwareLearner model object trained on all the data in the stream.

    Create trace plots of the number of training observations and the performance metrics. Mark the estimation period, the metrics warm-up period, and the training period.

    t = tiledlayout(2,1);
    nexttile
    plot(numtrainobs)
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk,"g-.","Estimation Period")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.MetricsWarmupPeriod)/numObsPerChunk,"m-.","Warmup Period")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.TrainingPeriod)/numObsPerChunk,"b--","Training Period")
    
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk+driftTimes,"g-.")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.MetricsWarmupPeriod)/numObsPerChunk+driftTimes,"m-.")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.TrainingPeriod)/numObsPerChunk+driftTimes,"b--")
    xline(driftTimes,"r","Drift",LabelVerticalAlignment="middle",LineWidth=1.5)
    xlim([0 nchunk])
    ylabel("Number of Training Observations")
    
    nexttile
    plot(ei.Variables)
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk,"g-.","Estimation Period")
    xline((idaMdl.MetricsWarmupPeriod+idaMdl.BaseLearner.EstimationPeriod)/numObsPerChunk,"m-.","Warmup Period")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.TrainingPeriod)/numObsPerChunk,"b--","Training Period")
    
    xline(idaMdl.BaseLearner.EstimationPeriod/numObsPerChunk+driftTimes,"g-.")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.MetricsWarmupPeriod)/numObsPerChunk+driftTimes,"m-.")
    xline((idaMdl.BaseLearner.EstimationPeriod+idaMdl.TrainingPeriod)/numObsPerChunk+driftTimes,"b--")
    xline(driftTimes,"r","Drift",LabelVerticalAlignment="middle",LineWidth=1.5)
    xlim([0 nchunk])
    legend(ei.Properties.VariableNames,Location="northeast")
    ylabel("Regression Error")
    xlabel(t,"Iteration")

    [Figure: two tiled plots against Iteration. Top: Number of Training Observations. Bottom: Regression Error, with Cumulative and Window curves. Each tile includes vertical reference lines.]

    Plot the drift status versus the iteration number.

    figure()
    gscatter(1:nchunk,dstatus,statusname,'gmr','*',5,'on',"Iteration","Drift Status")

    [Figure: Drift Status against Iteration. Legend entries: Stable, Warning, Drift.]

    Input Arguments

    Incremental drift-aware learning model Mdl to fit to streaming data, specified as an incrementalDriftAwareLearner model object. You can create Mdl using the incrementalDriftAwareLearner function. For more details, see the object reference page.

    Chunk of predictor data X to which the model is fit, specified as a floating-point matrix of n observations and Mdl.BaseLearner.NumPredictors predictor variables.

    When Mdl.BaseLearner accepts the ObservationsIn name-value argument, the value of ObservationsIn determines the orientation of the variables and observations. The default ObservationsIn value is "rows", which indicates that observations in the predictor data are oriented along the rows of X.

    The length of the observation responses (or labels) Y and the number of observations in X must be equal; Y(j) is the response (or label) of observation j (row or column) in X.

    Note

    • If Mdl.BaseLearner.NumPredictors = 0, fit infers the number of predictors from X, and sets the corresponding property of the output model. Otherwise, if the number of predictor variables in the streaming data differs from Mdl.BaseLearner.NumPredictors, fit issues an error.

    • fit supports only floating-point input predictor data. If your input data includes categorical data, you must prepare an encoded version of the categorical data. Use dummyvar to convert each categorical variable to a numeric matrix of dummy variables. Then, concatenate all dummy variable matrices and any other numeric predictors. For more details, see Dummy Variables.

    Data Types: single | double
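
    As a minimal sketch of the encoding step described in the note, the following code converts one categorical variable to dummy variables by using dummyvar and concatenates the result with a numeric predictor. The variable names and values (color, width, Xenc) are hypothetical and serve only as an illustration.

```matlab
% Hypothetical predictors: one categorical variable and one numeric variable
color = categorical(["red";"green";"blue";"green"]);
width = [1.2; 0.7; 3.1; 2.4];

D = dummyvar(color);   % One indicator column per category of color
Xenc = [D width];      % Numeric predictor matrix with one row per observation
```

    Here Xenc has four rows and four columns (three indicator columns plus width) and can be passed to fit in place of the original mixed-type data.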

    Chunk of responses (or labels) Y to which the model is fit, specified as one of the following:

    • Floating-point vector of n elements for regression models, where n is the number of observations in X.

    • Categorical, character, or string array, logical vector, or cell array of character vectors for classification models. If Y is a character array, it must have one class label per row. Otherwise, Y must be a vector with n elements.

    The length of Y and the number of observations in X must be equal; Y(j) is the response (or label) of observation j (row or column) in X.

    For classification problems:

    • When Mdl.BaseLearner.ClassNames is nonempty, the following conditions apply:

      • If Y contains a label that is not a member of Mdl.BaseLearner.ClassNames, fit issues an error.

      • The data type of Y and Mdl.BaseLearner.ClassNames must be the same.

    • When Mdl.BaseLearner.ClassNames is empty, fit infers Mdl.BaseLearner.ClassNames from data.

    Data Types: single | double | categorical | char | string | logical | cell

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: ObservationsIn="columns",Weights=W specifies that the columns of the predictor matrix correspond to observations, and the vector W contains observation weights to apply during incremental learning.

    Predictor data observation dimension (ObservationsIn), specified as "columns" or "rows".

    fit supports ObservationsIn only if Mdl.BaseLearner supports the ObservationsIn name-value argument.

    Example: ObservationsIn="columns"

    Data Types: char | string

    Chunk of observation weights (Weights), specified as a floating-point vector of positive values. fit weights the observations in X with the corresponding values in Weights. The length of Weights must equal n, which is the number of observations in X.

    By default, Weights is ones(n,1).

    Example: Weights=w

    Data Types: double | single
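
    The following sketch shows one way to pass both name-value arguments in a single call. It assumes a linear base learner, which accepts ObservationsIn. The simulated data, the weighting scheme, and the period lengths are hypothetical choices for illustration, and all variable names ending in Demo are introduced here.

```matlab
rng(0) % For reproducibility
nDemo = 200;
XDemo = randn(nDemo,3);                 % Simulated predictors, one row per observation
YDemo = XDemo(:,1) + 0.5*XDemo(:,2) > 0; % Simulated logical labels

wDemo = ones(nDemo,1);
wDemo(YDemo) = 2;                       % Hypothetical scheme: upweight the true class

baseDemo = incrementalClassificationLinear(Solver="sgd",EstimationPeriod=50);
idaDemo = incrementalDriftAwareLearner(baseDemo,TrainingPeriod=100);

% Pass the predictors with observations in columns, together with the weights
idaDemo = fit(idaDemo,XDemo',YDemo,ObservationsIn="columns",Weights=wDemo);
```

    Transposing XDemo places the observations in columns, which matches ObservationsIn="columns". The weight vector still has one element per observation.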

    Output Arguments

    Updated incremental drift-aware learning model, returned as an incremental drift-aware learning model object of the same data type as the input model Mdl, incrementalDriftAwareLearner.

    If Mdl.BaseLearner.EstimationPeriod > 0, the incremental fitting functions updateMetricsAndFit and fit estimate hyperparameters using the first Mdl.BaseLearner.EstimationPeriod observations passed to either function; they do not train the input model to the data. However, if an incoming chunk of n observations is greater than or equal to the number of observations remaining in the estimation period m, fit estimates hyperparameters using the first n - m observations, and fits the input model to the remaining m observations.

    For classification problems, if the ClassNames property of the input model Mdl.BaseLearner is an empty array, fit sets the ClassNames property of the output model Mdl.BaseLearner to unique(Y).

    Version History

    Introduced in R2022b