Source code for tsbootstrap.block_generator

from __future__ import annotations

import logging
import warnings
from typing import Optional

import numpy as np
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    PositiveInt,
    ValidationInfo,
    field_validator,
    model_validator,
)

from tsbootstrap.block_length_sampler import (
    MIN_BLOCK_LENGTH,
    BlockLengthSampler,
)
from tsbootstrap.utils.types import RngTypes
from tsbootstrap.utils.validate import validate_block_indices

# Initialize logger for this module using the module's name.
logger = logging.getLogger(__name__)


[docs] class BlockGenerator(BaseModel): """ Generates blocks of indices for time series bootstrapping. This class is responsible for creating blocks of indices that can be used to sample segments from a time series. It supports both overlapping and non-overlapping blocks and can optionally wrap around the end of the time series. Parameters ---------- input_length : PositiveInt The length of the input time series. Must be at least 3. block_length_sampler : BlockLengthSampler An instance of BlockLengthSampler to determine block lengths. wrap_around_flag : bool, optional If True, blocks can wrap around the end of the time series. Default is False. rng : RngTypes, optional Random number generator for sampling. Defaults to a new RNG instance. overlap_length : PositiveInt, optional The length of overlap between consecutive blocks. If not provided, defaults to half the average block length. min_block_length : PositiveInt, optional The minimum allowed block length. Defaults to `MIN_BLOCK_LENGTH`. Examples -------- >>> from tsbootstrap.block_length_sampler import BlockLengthSampler, DistributionTypes >>> sampler = BlockLengthSampler(avg_block_length=5, block_length_distribution=DistributionTypes.GAMMA) >>> generator = BlockGenerator(input_length=100, block_length_sampler=sampler, wrap_around_flag=True) >>> blocks = generator.generate_blocks(overlap_flag=True) >>> print(blocks) [array([...]), array([...]), ...] """ # Configuration for Pydantic 2 using ConfigDict model_config = ConfigDict( arbitrary_types_allowed=True, # Allows custom types like Generator validate_assignment=True, # Validates fields on assignment ) # Define class attributes with Pydantic Fields block_length_sampler: BlockLengthSampler = Field( ..., description="Sampler for determining block lengths." ) input_length: PositiveInt = Field( ..., ge=3, description="The length of the input time series." ) wrap_around_flag: bool = Field( default=False, description="Flag to allow wrap-around in block generation.", ) rng: RngTypes = Field( # type: ignore default_factory=lambda: np.random.default_rng(seed=None), description="Random number generator.", ) overlap_length: Optional[PositiveInt] = Field( default=0, description="Overlap length between blocks." ) min_block_length: Optional[PositiveInt] = Field( default=MIN_BLOCK_LENGTH, ge=MIN_BLOCK_LENGTH, description="Minimum block length.", )
[docs] @field_validator("block_length_sampler", mode="before") @classmethod def validate_block_length_sampler( cls, v: BlockLengthSampler, info: ValidationInfo ) -> BlockLengthSampler: """ Validate the block_length_sampler to ensure its average block length does not exceed the input_length. Parameters ---------- v : BlockLengthSampler The BlockLengthSampler instance to validate. info : ValidationInfo Provides context about the validation, including other fields. Returns ------- BlockLengthSampler The validated BlockLengthSampler instance. Raises ------ ValueError If the sampler's average block length exceeds the input_length. """ input_length: Optional[int] = info.data.get("input_length") if input_length is not None and v.avg_block_length > input_length: error_msg = ( f"'sampler.avg_block_length' ({v.avg_block_length}) must be <= " f"'input_length' ({input_length})." ) logger.error(error_msg) raise ValueError(error_msg) logger.debug(f"Block length sampler validated: {v}") return v
[docs] @field_validator("overlap_length", mode="before") @classmethod def validate_overlap_length( cls, v: Optional[int], info: ValidationInfo ) -> int: """ Validate and adjust the overlap_length parameter. If overlap_length is provided and is greater than or equal to input_length, it is adjusted to input_length - 1 with a warning. If overlap_length is not provided, it defaults to half of the average block length. Parameters ---------- v : Optional[int] The overlap length to validate. info : ValidationInfo Provides context about the validation, including other fields. Returns ------- int The validated and possibly adjusted overlap length. Raises ------ ValueError If input_length or block_length_sampler is not provided. """ input_length: Optional[int] = info.data.get("input_length") block_length_sampler: Optional[BlockLengthSampler] = info.data.get( "block_length_sampler" ) if input_length is None or block_length_sampler is None: error_msg = ( "'input_length' and 'block_length_sampler' must be provided." ) logger.error(error_msg) raise ValueError(error_msg) if v is None: # Default to half of the average block length if overlap_length is not provided default_overlap: int = block_length_sampler.avg_block_length // 2 logger.debug( f"No 'overlap_length' provided. Defaulting to {default_overlap}." ) return default_overlap if v >= input_length: # Adjust overlap_length to input_length - 1 if it exceeds or equals input_length adjusted_overlap = input_length - 1 warnings.warn( f"'overlap_length' ({v}) should be < 'input_length' ({input_length}). " f"Setting to {adjusted_overlap}.", stacklevel=2, ) logger.warning( f"'overlap_length' ({v}) exceeds or equals 'input_length' ({input_length}). " f"Adjusted to {adjusted_overlap}." ) return adjusted_overlap logger.debug(f"'overlap_length' validated: {v}") return v
[docs] @field_validator("min_block_length", mode="before") @classmethod def validate_min_block_length( cls, v: Optional[int], info: ValidationInfo ) -> int: """ Validate and adjust the min_block_length parameter. If min_block_length is provided, it must be between MIN_BLOCK_LENGTH and the sampler's average block length. Otherwise, it defaults to MIN_BLOCK_LENGTH. Parameters ---------- v : Optional[int] The minimum block length to validate. info : ValidationInfo Provides context about the validation, including other fields. Returns ------- int The validated and possibly adjusted minimum block length. Raises ------ ValueError If block_length_sampler is not provided. """ block_length_sampler: Optional[BlockLengthSampler] = info.data.get( "block_length_sampler" ) if block_length_sampler is None: error_msg = "'block_length_sampler' must be provided." logger.error(error_msg) raise ValueError(error_msg) if v is None: # Default to MIN_BLOCK_LENGTH if min_block_length is not provided logger.debug( f"No 'min_block_length' provided. Defaulting to {MIN_BLOCK_LENGTH}." ) return MIN_BLOCK_LENGTH if v > block_length_sampler.avg_block_length: # Adjust min_block_length to sampler's avg_block_length if it exceeds adjusted_min_block_length = block_length_sampler.avg_block_length warnings.warn( f"'min_block_length' ({v}) should be <= 'avg_block_length' " f"({block_length_sampler.avg_block_length}). " f"Setting to {adjusted_min_block_length}.", stacklevel=2, ) logger.warning( f"'min_block_length' ({v}) exceeds 'avg_block_length' " f"({block_length_sampler.avg_block_length}). " f"Adjusted to {adjusted_min_block_length}." ) return adjusted_min_block_length logger.debug(f"'min_block_length' validated: {v}") return v
[docs] @model_validator(mode="after") def check_consistency(self) -> BlockGenerator: """ Perform inter-field validation to ensure consistency among fields. This validator runs after all field validators have processed their respective fields, ensuring that interdependent fields maintain logical consistency. Returns ------- BlockGenerator The validated BlockGenerator instance. Raises ------ ValueError If any of the consistency checks fail. """ input_length: int = self.input_length sampler: BlockLengthSampler = self.block_length_sampler overlap_length: int = self.overlap_length # type: ignore min_block_length: int = self.min_block_length # type: ignore if sampler.avg_block_length > input_length: error_msg = ( f"'sampler.avg_block_length' ({sampler.avg_block_length}) must be <= " f"'input_length' ({input_length})." ) logger.error(error_msg) raise ValueError(error_msg) if overlap_length >= input_length: error_msg = ( f"'overlap_length' ({overlap_length}) must be less than " f"'input_length' ({input_length})." ) logger.error(error_msg) raise ValueError(error_msg) if min_block_length > sampler.avg_block_length: error_msg = ( f"'min_block_length' ({min_block_length}) cannot exceed " f"'avg_block_length' ({sampler.avg_block_length})." ) logger.error(error_msg) raise ValueError(error_msg) logger.debug("All inter-field consistency checks passed.") return self
def _create_block(self, start_index: int, block_length: int) -> np.ndarray: """ Create a block of indices. Parameters ---------- start_index : int Starting index of the block. block_length : int Length of the block. Returns ------- np.ndarray An array representing the indices of a block in the time series. """ end_index = start_index + block_length if end_index <= self.input_length: # Block does not exceed input_length; no wrapping needed logger.debug( f"Creating block from {start_index} to {end_index} without wrapping." ) return np.arange(start_index, end_index) else: if self.wrap_around_flag: # Wrap around to the beginning end_index_wrapped = end_index % self.input_length logger.debug( f"Creating block with wrap-around: {start_index} to {self.input_length} and 0 to {end_index_wrapped}." ) return np.concatenate( ( np.arange(start_index, self.input_length), np.arange(0, end_index_wrapped), ) ) else: # Adjust block_length to fit within input_length without wrapping adjusted_end_index = self.input_length logger.debug( f"Creating block from {start_index} to {adjusted_end_index} without wrapping." ) return np.arange(start_index, adjusted_end_index) def _calculate_start_index(self) -> int: """ Calculate the starting index of a block. Returns ------- int The starting index of the block. """ if self.wrap_around_flag: start = self.rng.integers(self.input_length) # type: ignore logger.debug( f"Wrap-around enabled. Randomly selected starting index: {start}." ) return start else: logger.debug("Wrap-around disabled. Starting index set to 0.") return 0 def _calculate_overlap_length(self, sampled_block_length: int) -> int: """ Calculate the overlap length for a block. Parameters ---------- sampled_block_length : int The length of the sampled block. Returns ------- int The calculated overlap length. """ if self.overlap_length is None: # Default to half of the sampled block length if overlap_length is not provided calculated_overlap = sampled_block_length // 2 logger.debug( f"'overlap_length' not provided. Calculated default overlap: {calculated_overlap}." ) return calculated_overlap else: # Ensure overlap_length is at least 1 and less than block_length calculated_overlap = min( # type: ignore max(self.overlap_length, 1), sampled_block_length - 1, ) logger.debug(f"Calculated overlap length: {calculated_overlap}.") return calculated_overlap def _get_total_length_covered( self, block_length: int, overlap_length: int ) -> int: """ Get the total length covered in the time series considering the current block length and overlap length. Parameters ---------- block_length : int The length of the current block. overlap_length : int The overlap length between the current and next block. Returns ------- int The total length covered so far. """ total_length = block_length - overlap_length logger.debug(f"Total length covered increased by {total_length}.") return total_length def _get_next_block_length( self, sampled_block_length: int, total_length_covered: int ) -> int: """ Get the next block length after considering wrap-around and total length covered. Parameters ---------- sampled_block_length : int The sampled block length from the block length sampler. total_length_covered : int The total length covered so far. Returns ------- int The adjusted block length. """ if not self.wrap_around_flag: # Ensure block_length does not exceed remaining input length remaining_length = self.input_length - total_length_covered # type: ignore adjusted_block_length = min(sampled_block_length, remaining_length) logger.debug( f"Wrap-around disabled. Adjusted block length from {sampled_block_length} to {adjusted_block_length} based on remaining length {remaining_length}." ) return adjusted_block_length else: # Wrap-around is allowed; use sampled block length as is logger.debug( f"Wrap-around enabled. Using sampled block length: {sampled_block_length}." ) return sampled_block_length def _calculate_next_start_index( self, start_index: int, block_length: int, overlap_length: int, ) -> int: """ Calculate the next start index for generating the subsequent block. Parameters ---------- start_index : int The start index of the current block. block_length : int The length of the current block. overlap_length : int The overlap length between the current and next block. Returns ------- int The start index for the next block. """ next_start_index = start_index + block_length - overlap_length next_start_index %= self.input_length logger.debug(f"Next start index calculated: {next_start_index}.") return next_start_index
[docs] def generate_non_overlapping_blocks(self) -> list[np.ndarray]: """ Generate non-overlapping block indices in the time series. Returns ------- list[np.ndarray] list of numpy arrays containing the indices for each non-overlapping block. Raises ------ ValueError If the block length sampler is not set. """ logger.info("Generating non-overlapping blocks.") block_indices = [] start_index = self._calculate_start_index() total_length = 0 while total_length < self.input_length: # type: ignore sampled_block_length = ( self.block_length_sampler.sample_block_length() ) logger.debug(f"Sampled block length: {sampled_block_length}.") block_length = self._get_next_block_length( sampled_block_length, total_length ) logger.debug(f"Adjusted block length: {block_length}.") block = self._create_block(start_index, block_length) logger.debug(f"Generated block: {block}.") block_indices.append(block) total_length += block_length logger.debug(f"Total length covered so far: {total_length}.") start_index = self._calculate_next_start_index( start_index, block_length, overlap_length=0 ) validate_block_indices(block_indices, self.input_length) # type: ignore logger.info("Non-overlapping block generation completed successfully.") return block_indices
[docs] def generate_overlapping_blocks(self) -> list[np.ndarray]: """ Generate overlapping block indices in the time series. Returns ------- list[np.ndarray] A list of numpy arrays where each array represents the indices of a block in the time series. Notes ----- The block indices are generated as follows: 1. A starting index is determined based on the `wrap_around_flag`. 2. A block length is sampled from the `block_length_sampler`. 3. An overlap length is calculated from the block length. 4. A block is created from the starting index and block length. 5. The starting index is updated to the next starting index. 6. Steps 2-5 are repeated until the total length covered equals the length of the time series. The block length sampler is used to sample the block length. The overlap length is calculated based on the sampled block length. The starting index is updated by adding the block length and subtracting the overlap length, then taking modulo `input_length` to ensure it wraps around if necessary. """ logger.info("Generating overlapping blocks.") block_indices = [] start_index = self._calculate_start_index() total_length_covered = 0 start_indices = set() while total_length_covered < self.input_length: # type: ignore if start_index in start_indices: # Detect repeating start index to prevent infinite loops warnings.warn( "Detected repeating start index. Terminating block generation to prevent infinite loop.", stacklevel=2, ) logger.warning( "Repeating start index detected. Terminating block generation." ) break start_indices.add(start_index) sampled_block_length = ( self.block_length_sampler.sample_block_length() ) logger.debug(f"Sampled block length: {sampled_block_length}.") block_length = self._get_next_block_length( sampled_block_length, total_length_covered ) logger.debug(f"Adjusted block length: {block_length}.") if block_length < self.min_block_length: # type:ignore # Terminate if the block length is below the minimum required logger.debug( f"Block length {block_length} is less than minimum {self.min_block_length}. Terminating." ) break overlap_length = self._calculate_overlap_length(block_length) logger.debug(f"Calculated overlap length: {overlap_length}.") block = self._create_block(start_index, block_length) logger.debug(f"Generated block: {block}.") block_indices.append(block) total_length_covered += self._get_total_length_covered( block_length, overlap_length ) logger.debug( f"Total length covered so far: {total_length_covered}." ) start_index = self._calculate_next_start_index( start_index, block_length, overlap_length ) # Prevent exceeding the maximum allowed iterations if total_length_covered >= self.input_length: logger.info("Overlapping block generation completed successfully.") else: warnings.warn( "Maximum iterations reached during block generation. Some input may remain uncovered.", stacklevel=2, ) logger.warning( "Maximum iterations reached. Some input may remain uncovered." ) validate_block_indices(block_indices, self.input_length) # type: ignore return block_indices
[docs] def generate_blocks(self, overlap_flag: bool = False) -> list[np.ndarray]: """ Generate block indices based on the overlap flag. This method serves as a general entry point to generate either overlapping or non-overlapping blocks based on the provided flag. Parameters ---------- overlap_flag : bool, optional If True, generate overlapping blocks. Otherwise, generate non-overlapping blocks. Default is False. Returns ------- list[np.ndarray] list of numpy arrays representing the generated blocks. """ if overlap_flag: return self.generate_overlapping_blocks() else: return self.generate_non_overlapping_blocks()