Custom loss functions

As an example, we will implement ridge regularization. Maximum likelihood estimation with ridge regularization consists of optimizing the objective

\[F_{ML}(\theta) + \alpha \lVert \theta_I \rVert^2_2\]

Since we allow for the optimization of sums of loss functions, and the maximum likelihood loss function already exists, we only need to implement the ridge part (and additionally get ridge regularization for WLS and FIML estimation for free).
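Because the total objective is simply a sum of loss terms, adding a penalty never requires touching the existing loss. The following package-independent sketch illustrates this. ml_stub is a hypothetical placeholder standing in for an existing loss such as maximum likelihood (it is not F_ML). ridge_term and total_objective are likewise illustrative names, not package API.

```julia
# Hypothetical placeholder for an existing loss term (NOT the real F_ML).
ml_stub(θ) = sum(abs2, θ .- 1.0)

# The new ridge term: α times the squared norm of the parameters indexed by I.
ridge_term(θ; α = 0.01, I = [1, 2]) = α * sum(abs2, θ[I])

# The total objective is the sum of all loss terms.
total_objective(θ, terms) = sum(term(θ) for term in terms)

θ = [2.0, -1.0, 0.5]
total_objective(θ, (ml_stub, ridge_term))   # ≈ 5.25 + 0.05 = 5.3
```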

Minimal

To define a new loss function, you have to define a new type that is a subtype of AbstractLoss:

struct MyRidge <: AbstractLoss
    α
    I
end

We store the hyperparameter α and the indices I of the parameters we want to regularize.

Additionally, we need to define a method of the function evaluate! to compute the objective:

import StructuralEquationModels: evaluate!

evaluate!(objective::Number, gradient::Nothing, hessian::Nothing, ridge::MyRidge, par) =
    ridge.α * sum(i -> abs2(par[i]), ridge.I)
evaluate! (generic function with 10 methods)

The function evaluate! uses the types of the arguments objective, gradient and hessian to decide whether to compute the objective value, the gradient or the hessian of the model w.r.t. the parameters. In this case, gradient and hessian are of type Nothing, signifying that only the objective value should be computed.
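The dispatch idea can be tried out without the package. In the sketch below, ToyRidge and toy_evaluate! are hypothetical stand-ins that only mimic the argument convention described above: passing nothing in a slot means that quantity is not requested, and Julia's multiple dispatch picks the matching method.

```julia
# Hypothetical stand-ins; NOT part of StructuralEquationModels.
struct ToyRidge
    α::Float64
    I::Vector{Int}
end

# Objective only: the gradient and hessian slots are `nothing`.
toy_evaluate!(objective::Number, gradient::Nothing, hessian::Nothing, r::ToyRidge, par) =
    r.α * sum(i -> abs2(par[i]), r.I)

# Gradient only: the objective slot is `nothing`; the gradient buffer is filled in place.
function toy_evaluate!(objective::Nothing, gradient::AbstractVector, hessian::Nothing,
                       r::ToyRidge, par)
    fill!(gradient, 0)
    gradient[r.I] .= 2 * r.α .* par[r.I]
    return nothing
end

r = ToyRidge(0.5, [1, 3])
par = [1.0, 2.0, 3.0]
toy_evaluate!(0.0, nothing, nothing, r, par)   # 0.5 * (1 + 9) = 5.0
g = zeros(3)
toy_evaluate!(nothing, g, nothing, r, par)     # g is now [1.0, 0.0, 3.0]
```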

That's all we need to make it work! We can now fit a model with ridge regularization.

First, we label some of the parameters so that we can identify them as targets for the regularization:

observed_vars = [:x1, :x2, :x3, :y1, :y2, :y3, :y4, :y5, :y6, :y7, :y8]
latent_vars = [:ind60, :dem60, :dem65]

graph = @StenoGraph begin

    # loadings
    ind60 → fixed(1)*x1 + x2 + x3
    dem60 → fixed(1)*y1 + y2 + y3 + y4
    dem65 → fixed(1)*y5 + y6 + y7 + y8

    # latent regressions
    ind60 → label(:a)*dem60
    dem60 → label(:b)*dem65
    ind60 → label(:c)*dem65

    # variances
    _(observed_vars) ↔ _(observed_vars)
    _(latent_vars) ↔ _(latent_vars)

    # covariances
    y1 ↔ y5
    y2 ↔ y4 + y6
    y3 ↔ y7
    y8 ↔ y4 + y6

end

partable = ParameterTable(
    graph,
    latent_vars = latent_vars,
    observed_vars = observed_vars
)

parameter_indices = getindex.([param_indices(partable)], [:a, :b, :c])
myridge = MyRidge(0.01, parameter_indices)

model = SemFiniteDiff(
    specification = partable,
    data = example_data("political_democracy"),
    loss = (SemML, myridge)
)

model_fit = fit(model)
Fitted Structural Equation Model 
=============================================== 
--------------------- Model ------------------- 

Structural Equation Model : Finite Difference Approximation- Loss Functions 
  > SemML
    - observed:    SemObservedData 
    - implied:     RAM 
  > MyRidge

------------- Optimization result ------------- 

engine: Optim

 * Status: success

 * Candidate solution
    Final objective value:     2.123542e+01

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 3.73e-05 ≰ 1.5e-08
    |x - x'|/|x'|          = 5.00e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.59e-09 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 7.49e-11 ≤ 1.0e-10
    |g(x)|                 = 1.54e-04 ≰ 1.0e-08

 * Work counters
    Seconds run:   1  (vs limit Inf)
    Iterations:    189
    f(x) calls:    229
    ∇f(x) calls:   229
    ∇f(x)ᵀv calls: 0

This is one way of specifying the model: we now have one model with multiple loss functions. Because we did not provide a gradient for MyRidge, we have to specify a SemFiniteDiff model, which computes the gradient numerically by finite difference approximation.
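To see why numerical gradients are costly, here is a minimal central-difference gradient. It is a sketch of the general technique, not the package's implementation (fd_gradient is a hypothetical helper, and SemFiniteDiff may use a different scheme or step size). A central difference needs two objective evaluations per parameter, so every gradient of an n-parameter model costs 2n objective evaluations.

```julia
# Hypothetical helper: central finite differences, f'(θ)ₖ ≈ (f(θ + h eₖ) - f(θ - h eₖ)) / 2h.
function fd_gradient(f, θ::AbstractVector{<:Real}; h = 1e-6)
    g = similar(θ, Float64)
    θp = float.(θ)                      # working copy that we perturb in place
    for k in eachindex(θp)
        orig = θp[k]
        θp[k] = orig + h; fplus = f(θp)
        θp[k] = orig - h; fminus = f(θp)
        θp[k] = orig                    # restore before the next coordinate
        g[k] = (fplus - fminus) / (2h)
    end
    return g
end

ridge(θ) = 0.01 * sum(abs2, θ[[1, 2]])
fd_gradient(ridge, [2.0, -1.0, 0.5])    # ≈ [0.04, -0.02, 0.0]
```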

Ridge regularization only depends on the parameters, so the evaluate! method above does not need anything else. Other loss functions, however, depend on the observed data and on what the model implies about it. Loss functions that compare the implied and the observed structure are subtypes of SemLoss and store their own observed and implied parts, which can be accessed inside evaluate! via observed(loss) and implied(loss). See Second example - maximum likelihood for information on how to do that.

Improve performance

By far the biggest improvements in performance will result from specifying analytical gradients. We can do this for our example:

function evaluate!(objective, gradient, hessian::Nothing, ridge::MyRidge, par)
    # compute gradient
    if !isnothing(gradient)
        fill!(gradient, 0)
        gradient[ridge.I] .= 2 * ridge.α * par[ridge.I]
    end
    # compute objective
    if !isnothing(objective)
        return ridge.α * sum(i -> par[i]^2, ridge.I)
    end
end
evaluate! (generic function with 11 methods)

As you can see, in this method definition both objective and gradient can be different from nothing. We then check with isnothing(objective) and isnothing(gradient) whether to compute the objective value and/or the gradient. This makes it possible to compute the objective value and the gradient in the same call, which is beneficial when they share common computations.
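The following package-independent sketch shows the pattern of computing both quantities in one call so that shared intermediate results are computed only once. For ridge the shared work is trivial (extracting the regularized parameters), but the structure carries over to losses where it is expensive. ridge_objective_and_gradient! is a hypothetical helper, not package API.

```julia
# Hypothetical helper: fill `gradient` and return the objective in one pass.
function ridge_objective_and_gradient!(gradient::AbstractVector, α, I, par)
    v = par[I]                  # intermediate result shared by both quantities
    fill!(gradient, 0)
    gradient[I] .= 2α .* v      # ∂(α θᵢ²)/∂θᵢ = 2α θᵢ
    return α * sum(abs2, v)     # objective value
end

par = [2.0, -1.0, 0.5]
g = zeros(length(par))
obj = ridge_objective_and_gradient!(g, 0.01, [1, 2], par)
# obj ≈ 0.05, g ≈ [0.04, -0.02, 0.0]
```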

Now, instead of specifying a SemFiniteDiff, we can use the normal Sem constructor:

model_new = Sem(
    specification = partable,
    data = example_data("political_democracy"),
    loss = (SemML, myridge)
)

model_fit = fit(model_new)
Fitted Structural Equation Model 
=============================================== 
--------------------- Model ------------------- 

Structural Equation Model- Loss Functions 
  > SemML
    - observed:    SemObservedData 
    - implied:     RAM 
  > MyRidge

------------- Optimization result ------------- 

engine: Optim

 * Status: success

 * Candidate solution
    Final objective value:     2.123542e+01

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 2.90e-05 ≰ 1.5e-08
    |x - x'|/|x'|          = 3.89e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.85e-09 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 8.69e-11 ≤ 1.0e-10
    |g(x)|                 = 1.13e-04 ≰ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    170
    f(x) calls:    215
    ∇f(x) calls:   215
    ∇f(x)ᵀv calls: 0

The results are the same, but the computational cost is much lower. We can verify this with the Julia package BenchmarkTools (which has to be installed):

using BenchmarkTools

@benchmark fit(model)

@benchmark fit(model_new)

The exact benchmark results of course depend heavily on your system (processor, RAM, etc.), but you should see that the median computation time with analytical gradients drops to about 5% of the computation time without them.

Additionally, you may provide analytic hessians by writing a corresponding method for evaluate!. However, this only matters if you use an optimization algorithm that makes use of hessians. Our default algorithm, L-BFGS from the package Optim.jl, does not use them (the Newton algorithm from the same package, for example, does).
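For ridge, the hessian is particularly simple: the penalty α Σᵢ θᵢ² (summed over the indices in I) has second derivative 2α on the diagonal entries in I and zero everywhere else, independent of the parameter values. The sketch below only builds that matrix; ridge_hessian! is a hypothetical helper, and how such a matrix would be wired into an evaluate! method with a non-Nothing hessian argument is not shown here.

```julia
# Hypothetical helper: the (constant) hessian of α * Σ_{i ∈ I} θᵢ².
function ridge_hessian!(H::AbstractMatrix, α, I)
    fill!(H, 0)
    for i in I
        H[i, i] = 2α
    end
    return H
end

H = ridge_hessian!(zeros(3, 3), 0.01, [1, 2])
# H ≈ [0.02 0 0; 0 0.02 0; 0 0 0]
```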

Convenient

In the minimal example above we built myridge ourselves and passed the ready-made instance to the model via loss = (SemML, myridge). Alternatively, you can let the outer Sem constructor build the loss term for you: pass the loss type instead of an instance and provide a keyword constructor.

MyRidge(; α_ridge, which_ridge, kwargs...) = MyRidge(α_ridge, which_ridge)

Any keyword arguments passed to Sem(...) are forwarded to this constructor (along with some that the model supplies automatically, such as nparams), so the loss can be configured directly from the model call:

model = SemFiniteDiff(
    specification = partable,
    data = example_data("political_democracy"),
    loss = (SemML, MyRidge),
    α_ridge = 0.01,
    which_ridge = parameter_indices,
)

Note that, being a plain AbstractLoss, MyRidge neither stores nor receives an observed or implied part — it depends only on the parameters. SEM-specific loss functions are constructed differently; see Second example - maximum likelihood.
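The keyword-constructor pattern can be illustrated without the package. In this sketch, KwRidge and build_loss are hypothetical. build_loss merely imitates a model constructor that forwards the user's keywords together with one of its own (nparams, with an arbitrary value), and kwargs... in the loss constructor absorbs keywords the loss does not need.

```julia
# Hypothetical loss type with a keyword constructor.
struct KwRidge
    α::Float64
    I::Vector{Int}
end

# `kwargs...` swallows keywords supplied by the model that this loss ignores.
KwRidge(; α_ridge, which_ridge, kwargs...) = KwRidge(α_ridge, which_ridge)

# Stand-in for a model constructor forwarding keywords (31 is arbitrary).
build_loss(T; kwargs...) = T(; nparams = 31, kwargs...)

r = build_loss(KwRidge; α_ridge = 0.01, which_ridge = [1, 2, 3])
```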

Additional functionality

Access additional information

If you want to provide a way to query information about loss functions of your type, you can provide functions for that:

hyperparameter(ridge::MyRidge) = ridge.α
regularization_indices(ridge::MyRidge) = ridge.I

Second example - maximum likelihood

Let's make a slightly more complicated example: we will reimplement maximum likelihood estimation.

To keep it simple, we only cover models without a meanstructure. The maximum likelihood objective is defined as

\[F_{ML} = \log \det \Sigma_i + \mathrm{tr}\left(\Sigma_{i}^{-1} \Sigma_o \right)\]

where $\Sigma_i$ is the model implied covariance matrix and $\Sigma_o$ is the observed covariance matrix. We can query the model implied covariance matrix from the implied part of our loss term, and the observed covariance matrix from the observed part of our loss term.
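To see the formula at work on plain matrices, here is a package-independent sketch (f_ml is a hypothetical helper). It uses Σᵢ \ Σₒ, which equals Σᵢ⁻¹Σₒ without forming the inverse. When the implied matrix equals the observed one, the trace term reduces to the number of observed variables p, so F_ML = log det Σₒ + p, the smallest value F_ML can take.

```julia
using LinearAlgebra

# Hypothetical helper: F_ML = log det Σᵢ + tr(Σᵢ⁻¹ Σₒ).
f_ml(Σᵢ, Σₒ) = logdet(Σᵢ) + tr(Σᵢ \ Σₒ)

Σₒ = [2.0 0.5; 0.5 1.0]
f_ml(Σₒ, Σₒ)                   # logdet(Σₒ) + 2, the minimum
f_ml(Matrix(1.0I, 2, 2), Σₒ)   # 0 + tr(Σₒ) = 3.0
```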

Since this loss function compares the implied and the observed structure, it is a subtype of SemLoss rather than a plain AbstractLoss. Every SemLoss stores its own observed and implied parts, which can be accessed inside evaluate! via observed(loss) and implied(loss).

To find out what we can access from a certain implied or observed type, we can check its documentation on the page API - model parts or use the help mode of the REPL:

julia>?

help?> RAM

help?> SemObservedData

We see that the model implied covariance matrix can be accessed as implied(loss).Σ and the observed covariance matrix as obs_cov(observed(loss)).

Unlike a plain AbstractLoss, a SemLoss subtype stores its observed and implied parts (in its first two fields), and the Sem constructor builds the loss term for you. To support this, every SemLoss subtype should provide a constructor with three positional arguments:

  • observed::SemObserved: the observed part of the loss term
  • implied::SemImplied: the implied part of the loss term
  • refloss::Union{MaximumLikelihood, Nothing} = nothing: an optional existing loss term of the same type, used as a reference for any loss-specific configuration.
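The convention can be mimicked with toy types to see how configuration flows between loss terms. In this sketch, ToyObserved, ToyImplied and ToyLoss are hypothetical and are not the package's SemObserved, SemImplied or SemLoss types; scale stands in for any loss-specific setting. The precedence implemented here (keyword, then refloss, then a built-in default) follows the rule stated for SemLoss constructors.

```julia
# Hypothetical toy types; NOT the package's types.
struct ToyObserved
    S::Matrix{Float64}
end
struct ToyImplied
    Σ::Matrix{Float64}
end

struct ToyLoss{O, I}
    observed::O
    implied::I
    scale::Float64      # stands in for any loss-specific configuration
end

# Precedence: keyword argument > refloss > built-in default.
function ToyLoss(observed, implied, refloss::Union{ToyLoss, Nothing} = nothing;
                 scale = nothing)
    s = !isnothing(scale) ? scale :
        !isnothing(refloss) ? refloss.scale : 1.0
    return ToyLoss{typeof(observed), typeof(implied)}(observed, implied, s)
end

obs1 = ToyObserved([1.0 0.0; 0.0 1.0])
obs2 = ToyObserved([2.0 0.0; 0.0 2.0])
imp = ToyImplied([1.0 0.0; 0.0 1.0])

base = ToyLoss(obs1, imp; scale = 2.0)         # explicit keyword
clone = ToyLoss(obs2, imp, base)               # scale inherited from refloss
over = ToyLoss(obs2, imp, base; scale = 5.0)   # keyword beats refloss
```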

Any additional configuration is passed as optional keyword arguments; if both refloss and keyword arguments are given, the keyword arguments take precedence. This constructor is also used by replace_observed to rebuild the loss term with new observed data while sharing the implied state. With this, we can implement maximum likelihood optimization as

struct MaximumLikelihood{O <: SemObserved, I <: SemImplied} <: SemLoss{O, I}
    observed::O
    implied::I
end

# constructor used by the `Sem` constructor to build the loss term
MaximumLikelihood(observed::SemObserved, implied::SemImplied, refloss = nothing; kwargs...) =
    MaximumLikelihood{typeof(observed), typeof(implied)}(observed, implied)

using LinearAlgebra
import StructuralEquationModels: evaluate!

function evaluate!(objective::Number, gradient::Nothing, hessian::Nothing, semml::MaximumLikelihood, par)
    # access the model implied and observed covariance matrices
    Σᵢ = implied(semml).Σ
    Σₒ = obs_cov(observed(semml))
    # compute the objective
    if isposdef(Symmetric(Σᵢ)) # is the model implied covariance matrix positive definite?
        return logdet(Σᵢ) + tr(inv(Σᵢ)*Σₒ)
    else
        return Inf
    end
end
evaluate! (generic function with 12 methods)

To deal with possible non-positive definiteness of the model implied covariance matrix, we chose the pragmatic approach of returning infinity whenever it occurs.
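A possible refinement, sketched here as a package-independent helper (f_ml_chol is hypothetical), is to attempt a single Cholesky factorization and reuse it. With check = false, cholesky returns a factorization object instead of throwing, issuccess reports whether Σᵢ was positive definite, and the same factor provides both the log-determinant and the solve Σᵢ⁻¹Σₒ. This avoids forming inv(Σᵢ) explicitly.

```julia
using LinearAlgebra

# Hypothetical helper: F_ML via one Cholesky factorization, Inf if not positive definite.
function f_ml_chol(Σᵢ::AbstractMatrix, Σₒ::AbstractMatrix)
    C = cholesky(Symmetric(Σᵢ); check = false)
    issuccess(C) || return Inf
    return logdet(C) + tr(C \ Σₒ)
end

Σₒ = [2.0 0.5; 0.5 1.0]
f_ml_chol(Σₒ, Σₒ)                   # finite: logdet(Σₒ) + 2
f_ml_chol([1.0 2.0; 2.0 1.0], Σₒ)   # Inf (eigenvalues 3 and -1)
```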

Let's specify and fit a model:

model_ml = SemFiniteDiff(
    specification = partable,
    data = example_data("political_democracy"),
    loss = MaximumLikelihood
)

model_fit = fit(model_ml)
Fitted Structural Equation Model 
=============================================== 
--------------------- Model ------------------- 

Structural Equation Model : Finite Difference Approximation- Loss Functions 
  > MaximumLikelihood
    - observed:    SemObservedData 
    - implied:     RAM 

------------- Optimization result ------------- 

engine: Optim

 * Status: success

 * Candidate solution
    Final objective value:     2.120543e+01

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.81e-05 ≰ 1.5e-08
    |x - x'|/|x'|          = 2.42e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.23e-09 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 5.79e-11 ≤ 1.0e-10
    |g(x)|                 = 1.49e-04 ≰ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    188
    f(x) calls:    232
    ∇f(x) calls:   232
    ∇f(x)ᵀv calls: 0

Supporting replace_observed

replace_observed swaps the observed data of a model while keeping the rest of the model (specification, implied type, loss configuration) intact. It is the backbone of Simulation studies and the bootstrap, where the same model is fitted to many datasets and rebuilding it from scratch each time would be wasteful.

The default mechanism

For a SemLoss term, the generic implementation rebuilds the term by calling its three-argument constructor with the new observed data, the original implied part, and the original loss term as refloss:

# simplified; see src/loss/abstract.jl
function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...)
    loss_ctor = typeof(loss).name.wrapper           # e.g. `MaximumLikelihood`
    return loss_ctor(new_observed, implied(loss), loss)  # third arg is the `refloss`
end

This is exactly the three-argument constructor every SemLoss already provides (see Second example - maximum likelihood). The refloss argument is what makes this work without re-deriving anything: the new term inherits the loss-specific configuration from the reference term and shares its implied state (and, where applicable, internal buffers). The implied part is shared rather than copied because it depends only on the model specification, not on the data.

Because of this, a loss term that does not cache anything derived from the observed data needs no extra code — implementing the three-argument constructor is enough, and replace_observed works out of the box. MaximumLikelihood above is such a case: it reads obs_cov(observed(loss)) on every evaluation and stores nothing, so it even ignores refloss entirely and is already fully compatible.
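The generic rebuild relies on typeof(loss).name.wrapper, which recovers the type without its parameters so that the three-argument constructor can be called. The following toy sketch shows this Julia mechanism in isolation. ToyML and toy_replace_observed are hypothetical, and .name.wrapper is a Julia reflection detail rather than part of the package's public API.

```julia
# Hypothetical toy loss with an observed and an implied part.
struct ToyML{O, I}
    observed::O
    implied::I
end
# Three-argument constructor; like MaximumLikelihood, it ignores `refloss`.
ToyML(observed, implied, refloss; kwargs...) = ToyML(observed, implied)

# Simplified generic rebuild: same type, new observed part, shared implied part.
toy_replace_observed(loss, new_observed) =
    typeof(loss).name.wrapper(new_observed, loss.implied, loss)

implied_state = [1.0 0.0; 0.0 1.0]
old = ToyML([2.0 0.5; 0.5 1.0], implied_state)
rebuilt = toy_replace_observed(old, [3.0 0.0; 0.0 3.0])
rebuilt.implied === old.implied   # true: shared, not copied
```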

Plain AbstractLoss terms (no observed part)

The mechanism above only applies to SemLoss terms, which carry an observed part. A plain AbstractLoss term — like the MyRidge regularizer from the Minimal example — depends only on the parameters and has no notion of observed data. There is therefore nothing to swap, and replace_observed returns such terms unchanged:

# src/loss/abstract.jl — fallback for non-SEM loss terms
replace_observed(loss::AbstractLoss, ::Any; kwargs...) = loss

This is handled by the default fallback, so you do not need to write anything for your own AbstractLoss types: when a model mixes SEM and non-SEM loss terms (e.g. loss = (SemML, MyRidge)), replace_observed rebuilds the SemML term with the new data and carries the MyRidge term over as-is. The recompute_observed_state keyword is likewise accepted and ignored.
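The interplay of the fallback and the specialized method is ordinary multiple dispatch. In the sketch below, all type and function names are hypothetical stand-ins for the package's types. It rebuilds a tuple of mixed loss terms: the data-dependent term is rebuilt around the new data, the penalty is returned unchanged, and both accept and ignore the recompute_observed_state keyword.

```julia
# Hypothetical stand-ins for the package's loss hierarchy.
abstract type ToyAbstractLoss end
abstract type ToySemLoss <: ToyAbstractLoss end

struct ToyDataLoss <: ToySemLoss
    observed::Matrix{Float64}
end
struct ToyPenalty <: ToyAbstractLoss
    α::Float64
end

# Fallback: terms without an observed part are returned unchanged.
toy_replace(loss::ToyAbstractLoss, ::Any; kwargs...) = loss
# More specific method: data-dependent terms are rebuilt around the new data.
toy_replace(loss::ToyDataLoss, new_obs::Matrix{Float64}; kwargs...) = ToyDataLoss(new_obs)

terms = (ToyDataLoss([1.0 0.0; 0.0 1.0]), ToyPenalty(0.01))
new_terms = map(t -> toy_replace(t, [2.0 0.0; 0.0 2.0]; recompute_observed_state = true),
                terms)
```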

If your regularizer does need to know about the data, the idiomatic solution is to make it a SemLoss (so it owns an observed part and participates in the rebuild) rather than to specialize replace_observed on a plain AbstractLoss.

When you need a custom method

You need to specialize replace_observed when your loss term precomputes and stores a quantity derived from the observed data. The default mechanism inherits that quantity from the refloss, so after swapping in new data the cached value would be stale.

SemWLS is the canonical example. Its weight matrix V defaults to the GLS weights computed from the observed covariance matrix and is stored on the term. If replace_observed simply reused refloss.V, the new term would weight the new data with the old data's weights. SemWLS therefore overrides replace_observed to recompute the weights from the new data by default, while exposing a recompute_observed_state keyword to opt out:

# src/loss/WLS/WLS.jl
function replace_observed(
    loss::SemWLS,
    new_observed::SemObserved;
    recompute_observed_state::Bool = true,
)
    return SemWLS(
        new_observed,
        implied(loss),
        loss;
        # pass `nothing` to recompute from the new data, or reuse the old matrices
        wls_weight_matrix = recompute_observed_state ? nothing : loss.V,
        wls_weight_matrix_mean = recompute_observed_state ? nothing : loss.V_μ,
    )
end

Note how the override still goes through the three-argument constructor with loss as refloss, so all other configuration (e.g. the choice of approximate vs. analytic Hessian) is still inherited automatically — the custom method only intervenes for the observed-dependent caches.

The recompute_observed_state keyword is a convention shared by all replace_observed methods: it is forwarded from the model-level call down to every loss term, and terms without observed-dependent caches simply ignore it.
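The stale-cache problem and the opt-out keyword can be reproduced with a toy loss. In this sketch, ToyWeighted and toy_replace_observed are hypothetical, and the cached weight is simply the inverse of the observed covariance matrix for illustration; it is not how SemWLS computes its GLS weights. As in SemWLS, passing nothing for the cached matrix triggers recomputation from the new data.

```julia
using LinearAlgebra

# Hypothetical loss that caches a quantity derived from the observed data.
struct ToyWeighted
    S::Matrix{Float64}      # observed covariance
    V::Matrix{Float64}      # cached weight, derived from S
end
ToyWeighted(S::Matrix{Float64}; V = nothing) = ToyWeighted(S, isnothing(V) ? inv(S) : V)

function toy_replace_observed(loss::ToyWeighted, S_new::Matrix{Float64};
                              recompute_observed_state::Bool = true)
    # `nothing` recomputes V from the new data; otherwise the old cache is reused
    return ToyWeighted(S_new; V = recompute_observed_state ? nothing : loss.V)
end

old = ToyWeighted([2.0 0.0; 0.0 4.0])
fresh = toy_replace_observed(old, [1.0 0.0; 0.0 1.0])
stale = toy_replace_observed(old, [1.0 0.0; 0.0 1.0]; recompute_observed_state = false)
# fresh.V is derived from the new data; stale.V still holds the old data's weights
```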