
Variables

Variables in LUME-torch define the inputs and outputs of models. They provide validation, type checking, and metadata about model parameters.

Overview

LUME-torch uses variables to:

  • Define input and output specifications for models
  • Validate values at runtime
  • Store metadata like units, ranges, and default values
  • Support configuration file serialization

Variable Types

Base Variable Classes

These are imported from the lume-base package:

Variable - Abstract base class for all variables.

  • Attributes: name, read_only, default_validation_config
  • Methods: validate_value(), model_dump()

ScalarVariable - Variable for scalar floating-point values.

  • Attributes: name, default_value, read_only, value_range, unit
  • Methods: validate_value() (validates type and range)

ConfigEnum - Validation configuration options.

  • "none": no validation
  • "warn": emit warnings for validation failures
  • "error": raise errors for validation failures

For complete API documentation of these classes, see the lume-base documentation.
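The three ConfigEnum modes map to skip/warn/raise behavior. The sketch below illustrates those semantics in isolation; `dispatch_validation_failure` is an illustrative helper, not a lume-base API:

```python
import warnings

# Sketch of the "none"/"warn"/"error" semantics described above.
# This mimics the dispatch pattern; it is not the library's implementation.
def dispatch_validation_failure(message: str, config: str) -> None:
    if config == "none":
        return                      # optional validation skipped entirely
    if config == "warn":
        warnings.warn(message)      # report the failure but keep going
    elif config == "error":
        raise ValueError(message)   # hard failure

dispatch_validation_failure("value out of range", config="none")  # silent

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    dispatch_validation_failure("value out of range", config="warn")
print(len(caught))  # 1 warning captured
```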

LUME-torch Specific Variables

lume_torch.variables.DistributionVariable

Bases: Variable

Variable for distributions. Must be a subclass of torch.distributions.Distribution.

Attributes

unit : str, optional
    Unit associated with the variable.

Source code in lume_torch/variables.py
class DistributionVariable(Variable):
    """Variable for distributions. Must be a subclass of torch.distributions.Distribution.

    Attributes
    ----------
    unit : str, optional
        Unit associated with the variable.

    """

    unit: Optional[str] = None

    def validate_value(self, value: TDistribution, config: ConfigEnum = None):
        """Validates the given value.

        Parameters
        ----------
        value : Distribution
            The value to be validated.
        config : ConfigEnum, optional
            The configuration for validation. Defaults to None.
            Allowed values are "none", "warn", and "error".

        Raises
        ------
        TypeError
            If the value is not an instance of Distribution.

        """
        _config = self.default_validation_config if config is None else config
        # mandatory validation
        self._validate_value_type(value)
        # optional validation
        if _config != "none":
            pass  # not implemented

    @staticmethod
    def _validate_value_type(value: TDistribution):
        if not isinstance(value, TDistribution):
            raise TypeError(
                f"Expected value to be of type {TDistribution}, "
                f"but received {type(value)}."
            )

validate_value(value, config=None)

Validates the given value.

Parameters

value : Distribution
    The value to be validated.
config : ConfigEnum, optional
    The configuration for validation. Defaults to None.
    Allowed values are "none", "warn", and "error".

Raises

TypeError
    If the value is not an instance of Distribution.


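The type gate in `validate_value` is a plain `isinstance` check. It can be sketched standalone; the classes below are stand-ins for `torch.distributions.Distribution` and a concrete subclass, not the real types:

```python
# Stand-ins for torch.distributions types, to show the isinstance-based
# type gate used by DistributionVariable.validate_value.
class Distribution:            # stand-in for torch.distributions.Distribution
    pass

class Normal(Distribution):    # stand-in for a concrete distribution
    pass

def validate_distribution(value):
    """Raise TypeError unless value is a Distribution instance."""
    if not isinstance(value, Distribution):
        raise TypeError(
            f"Expected value to be of type {Distribution}, "
            f"but received {type(value)}."
        )

validate_distribution(Normal())   # passes silently

try:
    validate_distribution(3.14)   # a float is not a Distribution
except TypeError as e:
    print(f"rejected: {e}")
```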
Utilities

lume_torch.variables.get_variable(name)

Returns the Variable subclass with the given name.

Parameters

name : str
    Name of the Variable subclass.

Returns

Type[Variable]
    Variable subclass with the given name.

Source code in lume_torch/variables.py
def get_variable(name: str) -> Type[Variable]:
    """Returns the Variable subclass with the given name.

    Parameters
    ----------
    name : str
        Name of the Variable subclass.

    Returns
    -------
    Type[Variable]
        Variable subclass with the given name.

    """
    classes = [ScalarVariable, DistributionVariable]
    class_lookup = {c.__name__: c for c in classes}
    if name not in class_lookup.keys():
        logger.error(
            f"Unknown variable type '{name}', valid names are {list(class_lookup.keys())}"
        )
        raise KeyError(
            f"No variable named {name}, valid names are {list(class_lookup.keys())}"
        )
    return class_lookup[name]
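This name-to-class registry is what supports deserializing variables from configuration files: a spec names the class as a string, and the lookup recovers the type. A standalone sketch with stand-in classes (not the library's types):

```python
# Stand-ins for the real variable classes; the lookup mirrors get_variable.
class ScalarVariable: ...
class DistributionVariable: ...

def get_variable(name):
    class_lookup = {c.__name__: c for c in (ScalarVariable, DistributionVariable)}
    if name not in class_lookup:
        raise KeyError(
            f"No variable named {name}, valid names are {list(class_lookup)}"
        )
    return class_lookup[name]

# e.g. rebuilding a variable type from a config-file entry
spec = {"variable_class": "ScalarVariable", "name": "energy"}
cls = get_variable(spec["variable_class"])
print(cls.__name__)  # ScalarVariable
```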

Usage Examples

Creating Scalar Variables

from lume_torch.variables import ScalarVariable

# Basic variable
var = ScalarVariable(name="temperature")

# Variable with range
var = ScalarVariable(
    name="pressure",
    default_value=1.0,
    value_range=[0.0, 10.0],
    unit="atm"
)

# Read-only variable
constant = ScalarVariable(
    name="speed_of_light",
    default_value=299792458.0,
    read_only=True,
    unit="m/s"
)

Creating Distribution Variables

from lume_torch.variables import DistributionVariable

# For probabilistic model outputs
dist_var = DistributionVariable(
    name="output_distribution",
    unit="GeV"
)

Validation

from lume_torch.variables import ScalarVariable, ConfigEnum

var = ScalarVariable(
    name="energy",
    value_range=[0.0, 100.0],
    default_validation_config=ConfigEnum.ERROR
)

# This will raise an error
try:
    var.validate_value(150.0)
except ValueError as e:
    print(f"Validation failed: {e}")

# Configure validation behavior
var.validate_value(50.0, config=ConfigEnum.WARN)  # Valid, no warning
var.validate_value(150.0, config=ConfigEnum.WARN)  # Warning but no error

Using with Models

from lume_torch.base import LUMETorch
from lume_torch.variables import ScalarVariable


class PhysicsModel(LUMETorch):
    def _evaluate(self, input_dict):
        return {"force": input_dict["mass"] * input_dict["acceleration"]}


input_vars = [
    ScalarVariable(name="mass", value_range=[0.1, 1000.0], unit="kg"),
    ScalarVariable(name="acceleration", value_range=[0.0, 100.0], unit="m/s^2"),
]
output_vars = [
    ScalarVariable(name="force", unit="N"),
]

model = PhysicsModel(
    input_variables=input_vars,
    output_variables=output_vars
)

See Also