Variables

Variables in LUME-torch define the inputs and outputs of models. They provide validation, type checking, and metadata about model parameters.

Overview

LUME-torch uses variables to:

  • Define input and output specifications for models
  • Validate values at runtime
  • Store metadata like units, ranges, and default values
  • Support configuration file serialization

Variable Types

Base Variable Classes

These are imported from the lume-base package:

Variable - Abstract base class for all variables
  • Attributes: name, read_only, default_validation_config
  • Methods: validate_value(), model_dump()

ScalarVariable - Variable for scalar floating-point values
  • Attributes: name, default_value, read_only, value_range, unit
  • Methods: validate_value(), which checks both type and range

ConfigEnum - Validation configuration options
  • "none": no validation
  • "warn": emit warnings on validation failures
  • "error": raise errors on validation failures

For complete API documentation of these classes, see the lume-base documentation.

LUME-torch Specific Variables

lume_torch.variables.DistributionVariable

Bases: Variable

Variable for distributions. Must be a subclass of torch.distributions.Distribution.

Attributes

unit : str, optional
    Unit associated with the variable.

Source code in lume_torch/variables.py
class DistributionVariable(Variable):
    """Variable for distributions. Must be a subclass of torch.distributions.Distribution.

    Attributes
    ----------
    unit : str, optional
        Unit associated with the variable.

    """

    unit: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def _compat_is_constant(cls, data: Any) -> Any:
        return _normalize_legacy_read_only(data)

    def validate_value(
        self, value: TDistribution, config: Optional[ConfigEnum] = None, **kwargs
    ):
        """Validates the given value.

        Parameters
        ----------
        value : Distribution
            The value to be validated.
        config : ConfigEnum, optional
            The configuration for validation. Defaults to None.
            Allowed values are "none", "warn", and "error".

        Raises
        ------
        TypeError
            If the value is not an instance of Distribution.

        """
        # mandatory validation
        self._validate_value_type(value)

        # optional validation
        config = self._validation_config_as_enum(config)
        if config != ConfigEnum.NULL:
            pass  # not implemented

    @staticmethod
    def _validate_value_type(value: TDistribution):
        if not isinstance(value, TDistribution):
            raise TypeError(
                f"Expected value to be of type {TDistribution}, "
                f"but received {type(value)}."
            )

validate_value(value, config=None, **kwargs)

Validates the given value.

Parameters

value : Distribution
    The value to be validated.
config : ConfigEnum, optional
    The configuration for validation. Defaults to None. Allowed values are "none", "warn", and "error".

Raises

TypeError
    If the value is not an instance of Distribution.

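The mandatory type check simply rejects any value that is not a Distribution instance. The pattern can be exercised standalone; the sketch below uses a placeholder Distribution base class in place of torch.distributions.Distribution, so it is an illustration of the check, not the library's code:

```python
class Distribution:
    """Placeholder standing in for torch.distributions.Distribution."""

class Normal(Distribution):
    """Placeholder concrete distribution."""

def validate_value_type(value, expected=Distribution):
    # Mirrors the isinstance check: raise TypeError on any non-Distribution value.
    if not isinstance(value, expected):
        raise TypeError(
            f"Expected value to be of type {expected}, but received {type(value)}."
        )

validate_value_type(Normal())  # a Distribution subclass instance passes silently
```

Passing anything else (a float, a tensor, a plain object) raises TypeError regardless of the optional validation config.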

Utilities

lume_torch.variables.get_variable(name)

Returns the Variable subclass with the given name.

Parameters

name : str
    Name of the Variable subclass.

Returns

Type[Variable]
    Variable subclass with the given name.

Source code in lume_torch/variables.py
def get_variable(name: str) -> Type[Variable]:
    """Returns the Variable subclass with the given name.

    Parameters
    ----------
    name : str
        Name of the Variable subclass.

    Returns
    -------
    Type[Variable]
        Variable subclass with the given name.

    """
    classes = [
        TorchScalarVariable,
        ScalarVariable,
        DistributionVariable,
        TorchNDVariable,
    ]
    class_lookup = {c.__name__: c for c in classes}
    if name not in class_lookup.keys():
        logger.error(
            f"Unknown variable type '{name}', valid names are {list(class_lookup.keys())}"
        )
        raise KeyError(
            f"No variable named {name}, valid names are {list(class_lookup.keys())}"
        )
    return class_lookup[name]
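The lookup is a plain name-to-class registry. Its behavior, including the KeyError path for unknown names, can be sketched with stand-in classes (not the real variable types):

```python
class ScalarVariable: ...
class DistributionVariable: ...

def get_variable(name):
    # Name-to-class registry, mirroring the lookup above.
    class_lookup = {c.__name__: c for c in (ScalarVariable, DistributionVariable)}
    if name not in class_lookup:
        raise KeyError(
            f"No variable named {name}, valid names are {list(class_lookup)}"
        )
    return class_lookup[name]

print(get_variable("ScalarVariable").__name__)  # -> ScalarVariable
```

This is the usual pattern for reconstructing variable objects from serialized configuration files, where only the class name is stored.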

Usage Examples

Creating Scalar Variables

from lume_torch.variables import ScalarVariable

# Basic variable
var = ScalarVariable(name="temperature")

# Variable with range
var = ScalarVariable(
    name="pressure",
    default_value=1.0,
    value_range=[0.0, 10.0],
    unit="atm"
)

# Read-only variable
constant = ScalarVariable(
    name="speed_of_light",
    default_value=299792458.0,
    read_only=True,
    unit="m/s"
)

Creating Distribution Variables

from lume_torch.variables import DistributionVariable

# For probabilistic model outputs
dist_var = DistributionVariable(
    name="output_distribution",
    unit="GeV"
)

Validation

from lume_torch.variables import ScalarVariable, ConfigEnum

var = ScalarVariable(
    name="energy",
    value_range=[0.0, 100.0],
    default_validation_config=ConfigEnum.ERROR
)

# This will raise an error
try:
    var.validate_value(150.0)
except ValueError as e:
    print(f"Validation failed: {e}")

# Configure validation behavior
var.validate_value(50.0, config=ConfigEnum.WARN)  # Valid, no warning
var.validate_value(150.0, config=ConfigEnum.WARN)  # Warning but no error
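The three ConfigEnum modes map onto a common check pattern: skip, warn, or raise. A standalone sketch of range validation under each mode (a simplification for illustration, not the library's implementation):

```python
import warnings

def check_range(value, lo, hi, config="error"):
    # "none" skips validation entirely; in-range values always pass.
    if config == "none" or lo <= value <= hi:
        return
    msg = f"value {value} outside range [{lo}, {hi}]"
    if config == "warn":
        warnings.warn(msg)   # report the failure but continue
    elif config == "error":
        raise ValueError(msg)

check_range(50.0, 0.0, 100.0, config="warn")   # in range: silent
check_range(150.0, 0.0, 100.0, config="none")  # validation disabled: silent
```

Only the "error" mode interrupts execution, which is why it is the safer default for model inputs while "warn" suits exploratory use.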

Using with Models

from lume_torch.base import LUMETorch
from lume_torch.variables import ScalarVariable


class PhysicsModel(LUMETorch):
    def _evaluate(self, input_dict):
        return {"force": input_dict["mass"] * input_dict["acceleration"]}


input_vars = [
    ScalarVariable(name="mass", value_range=[0.1, 1000.0], unit="kg"),
    ScalarVariable(name="acceleration", value_range=[0.0, 100.0], unit="m/s^2"),
]
output_vars = [
    ScalarVariable(name="force", unit="N"),
]

model = PhysicsModel(
    input_variables=input_vars,
    output_variables=output_vars
)
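The flow this example relies on — validate each input against its declared range, then dispatch to _evaluate — can be sketched without the library. This is a deliberately minimal stand-in; LUMETorch's actual evaluation API and validation hooks may differ:

```python
class MiniModel:
    """Toy stand-in for a LUMETorch-style model: range-check inputs, then evaluate."""

    def __init__(self, input_ranges):
        self.input_ranges = input_ranges  # name -> (lo, hi)

    def evaluate(self, input_dict):
        # Validate every declared input before running the model.
        for name, (lo, hi) in self.input_ranges.items():
            value = input_dict[name]
            if not lo <= value <= hi:
                raise ValueError(f"{name}={value} outside [{lo}, {hi}]")
        return self._evaluate(input_dict)


class PhysicsModel(MiniModel):
    def _evaluate(self, input_dict):
        # F = m * a
        return {"force": input_dict["mass"] * input_dict["acceleration"]}


model = PhysicsModel({"mass": (0.1, 1000.0), "acceleration": (0.0, 100.0)})
result = model.evaluate({"mass": 10.0, "acceleration": 5.0})  # {"force": 50.0}
```

Out-of-range inputs (e.g. a negative mass) are rejected before _evaluate ever runs, which is the main benefit of declaring value_range on input variables.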

See Also