class MLPBackbone(nn.Module):
    """Simple MLP for tabular features.

    Stacks ``depth`` ``Linear -> ReLU`` pairs. The first linear layer maps
    ``input_dim`` to ``hidden_dim``; every later one maps ``hidden_dim`` to
    ``hidden_dim``. With ``depth <= 0`` no layers are created, the network is
    an empty ``nn.Sequential``, and the module returns its input unchanged.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of each hidden layer (and of the output when
            ``depth > 0``).
        depth: Number of ``Linear -> ReLU`` pairs.

    Attributes:
        net: The ``nn.Sequential`` holding the layers.
        output_dim: Size of the last dimension of the tensor returned by
            :meth:`forward`.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 256, depth: int = 2):
        super().__init__()
        layers = []
        dim = input_dim
        for _ in range(depth):
            layers.extend([nn.Linear(dim, hidden_dim), nn.ReLU()])
            dim = hidden_dim
        self.net = nn.Sequential(*layers)
        # Expose the feature width so downstream heads can be sized without
        # re-deriving it; equals input_dim when no layers were built.
        self.output_dim = dim

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Apply the MLP to ``x``, whose features lie along the last dimension."""
        return self.net(x)
class TemporalEncoder(nn.Module):
    """GRU-based encoder for time-series signals.

    Wraps a unidirectional ``nn.GRU`` built with ``batch_first=True``. Each
    sequence is summarized by the final hidden state of the top GRU layer.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: GRU hidden size; also the size of the encoding.
        num_layers: Number of stacked GRU layers.

    Attributes:
        rnn: The underlying ``nn.GRU``.
        output_dim: Size of the encoding returned by :meth:`forward`.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, num_layers: int = 1):
        super().__init__()
        self.rnn = nn.GRU(
            input_dim, hidden_dim, num_layers=num_layers, batch_first=True
        )
        self.output_dim = hidden_dim

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Encode a batch of sequences.

        Args:
            x: Input of shape ``(batch, seq_len, input_dim)`` because the GRU
                is constructed with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``: the top layer's hidden
            state after the last time step.
        """
        _, h_n = self.rnn(x)
        # h_n has shape (num_layers, batch, hidden_dim); keep the top layer.
        return h_n[-1]