1 change: 1 addition & 0 deletions SYMBA_HEP/README.md
@@ -9,6 +9,7 @@ Computing squared amplitudes of Feynman diagrams from amplitudes is an $O(N^2)$
|---Vanilla Transformers + Original Data Generation
|---Longformer and BART + Engine
|---SKANFormer + Engine Updates + Updated Data Generation
|---
```


Empty file.
Empty file.
132 changes: 132 additions & 0 deletions SYMBA_HEP/SYMBAHEP_Hybrid_SSM_Prasanth_Naidu/config.py
@@ -0,0 +1,132 @@
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class ModelConfig:

    # Project & Run Information
    project_name: str
    run_name: str
    model_name: str

    # Directories
    root_dir: str
    data_dir: str

    # Hardware & Training Setup
    device: str
    epochs: int
    training_batch_size: int
    valid_batch_size: int
    num_workers: int

    # Model Architecture
    embedding_size: int
    nhead: int
    num_encoder_layers: int
    num_decoder_layers: int
    ff_dims: int

    # Optimization & Regularization
    warmup_ratio: float
    dropout: float
    weight_decay: float
    optimizer_lr: float
    is_constant_lr: bool

    # Sequence Configuration
    src_max_len: int
    tgt_max_len: int
    is_termwise: bool

    # Training Control
    curr_epoch: int
    train_shuffle: bool
    valid_shuffle: bool
    pin_memory: bool
    world_size: int
    resume_best: bool

    # Optional Parameters
    dtype: Optional[str] = 'bfloat16'
    run_id: Optional[str] = None
    backend: Optional[str] = 'nccl'
    src_voc_size: Optional[int] = None
    tgt_voc_size: Optional[int] = None
    save_freq: Optional[int] = 3
    save_limit: Optional[int] = 3
    seed: Optional[int] = 42
    update_lr: Optional[float] = None
    end_lr: Optional[float] = 1e-6
    clip_grad_norm: Optional[float] = -1
    save_last: Optional[bool] = True
    log_freq: Optional[int] = 50
    test_freq: Optional[int] = 10
    truncate: Optional[bool] = False
    filter_len: Optional[bool] = False
    debug: Optional[bool] = False
    to_replace: bool = False
    index_pool_size: int = 100

    def to_dict(self):
        """Convert dataclass to dictionary."""
        return asdict(self)


@dataclass
class ModelTestConfig:

    # Model name
    model_name: str

    # Directories where data and model checkpoints are stored
    root_dir: str
    data_dir: str

    # Device for inference (e.g., "cuda" for GPU, "cpu")
    device: str

    # Dimensionality of word embeddings
    embedding_size: int

    # Number of attention heads in the transformer model
    nhead: int

    # Number of encoder and decoder layers in the transformer model
    num_encoder_layers: int
    num_decoder_layers: int

    # FFN dims
    ff_dims: int

    # Dropout rate
    dropout: float

    # Maximum length of source and target sequences
    src_max_len: int
    tgt_max_len: int
    is_termwise: bool

    # Size of vocabulary for source and target sequences
    src_voc_size: Optional[int] = None
    tgt_voc_size: Optional[int] = None

    # Seed for reproducibility
    seed: Optional[int] = 42

    # Truncate sequences that exceed the maximum length
    truncate: Optional[bool] = False

    # Debug mode
    debug: Optional[bool] = False

    # Replace index and momentum tokens
    to_replace: bool = False

    # Token pool size
    index_pool_size: int = 100

    dtype: Optional[str] = 'bfloat16'

    def to_dict(self):
        return asdict(self)
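For reference, a minimal sketch of how `ModelConfig` might be instantiated. All concrete values below (paths, batch sizes, layer counts) are placeholder assumptions, not settings taken from this repository, and the import path depends on how the package is laid out:

```python
# Hypothetical usage sketch; every value here is a placeholder.
from config import ModelConfig  # adjust the import path to your package layout

config = ModelConfig(
    project_name="symba_hep",
    run_name="hybrid_ssm_baseline",
    model_name="hybrid_ssm",
    root_dir="./runs",
    data_dir="./data",
    device="cuda",
    epochs=10,
    training_batch_size=32,
    valid_batch_size=32,
    num_workers=4,
    embedding_size=512,
    nhead=8,
    num_encoder_layers=3,
    num_decoder_layers=3,
    ff_dims=2048,
    warmup_ratio=0.1,
    dropout=0.1,
    weight_decay=0.01,
    optimizer_lr=1e-4,
    is_constant_lr=False,
    src_max_len=512,
    tgt_max_len=512,
    is_termwise=False,
    curr_epoch=0,
    train_shuffle=True,
    valid_shuffle=False,
    pin_memory=True,
    world_size=1,
    resume_best=False,
)

# to_dict() flattens the config for logging or serialization.
print(config.to_dict()["model_name"])
```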
13 changes: 13 additions & 0 deletions SYMBA_HEP/SYMBAHEP_Hybrid_SSM_Prasanth_Naidu/constants.py
@@ -0,0 +1,13 @@
# Special token indices
BOS_IDX = 0 # Beginning of Sequence
PAD_IDX = 1 # Padding
EOS_IDX = 2 # End of Sequence
UNK_IDX = 3 # Unknown Token
SEP_IDX = 4 # Separator Token

# Indices reserved for the term tokens <T0> ... <T19>
T_IDX = list(range(5, 25))


# Special token symbols
SPL_TERM_SYMBOLS = [f'<T{i}>' for i in range(20)]
SPECIAL_SYMBOLS = ['<BOS>', '<PAD>', '<EOS>', '<UNK>', '<SEP>'] + SPL_TERM_SYMBOLS
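As a quick sanity check of how the indices and symbols line up, a small sketch (the `itos` mapping below is illustrative and not part of the repository; run it with the constants above in scope):

```python
# Illustrative only: SPECIAL_SYMBOLS is ['<BOS>', '<PAD>', '<EOS>', '<UNK>', '<SEP>', '<T0>', ..., '<T19>'],
# so enumerating it reproduces the index constants defined above.
itos = {idx: sym for idx, sym in enumerate(SPECIAL_SYMBOLS)}

assert itos[BOS_IDX] == '<BOS>' and itos[PAD_IDX] == '<PAD>' and itos[EOS_IDX] == '<EOS>'
assert itos[UNK_IDX] == '<UNK>' and itos[SEP_IDX] == '<SEP>'
assert [itos[i] for i in T_IDX] == SPL_TERM_SYMBOLS  # indices 5..24 map to <T0>..<T19>
```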
126 changes: 126 additions & 0 deletions SYMBA_HEP/SYMBAHEP_Hybrid_SSM_Prasanth_Naidu/data.py
@@ -0,0 +1,126 @@
from torch.utils.data import Dataset
import torch

from .constants import BOS_IDX, PAD_IDX, EOS_IDX
from .logger import get_logger

logger = get_logger(__name__)


class Data(Dataset):
    """
    Custom PyTorch dataset wrapping amplitude / squared-amplitude pairs.

    Args:
        df (DataFrame): DataFrame with 'amp' (source) and 'sqamp' (target) columns.
        tokenizer: Tokenizer exposing src_tokenize and tgt_tokenize.
        config: Configuration controlling sequence lengths, truncation and filtering.
        src_vocab: Vocabulary used to encode source tokens.
        tgt_vocab: Vocabulary used to encode target tokens.
    """

    def __init__(self, df, tokenizer, config, src_vocab, tgt_vocab):
        super(Data, self).__init__()
        self.config = config
        self.tgt_tokenize = tokenizer.tgt_tokenize
        self.src_tokenize = tokenizer.src_tokenize
        self.bos_token = torch.tensor([BOS_IDX], dtype=torch.int64)
        self.eos_token = torch.tensor([EOS_IDX], dtype=torch.int64)
        self.pad_token = torch.tensor([PAD_IDX], dtype=torch.int64)
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

        if self.config.filter_len:
            df = df[
                (df['sqamp'].str.len() <= self.config.tgt_max_len) &
                (df['amp'].str.len() <= self.config.src_max_len)
            ].reset_index(drop=True)
            logger.info(f"Filtered data size is: {len(df)}")

        self.tgt_vals = df['sqamp']
        self.src_vals = df['amp']



    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.src_vals)

    def __getitem__(self, idx):
        """
        Get an item from the dataset at the specified index.

        Args:
            idx (int): Index of the item.

        Returns:
            tuple: Tuple containing source and target tensors.
        """
        src_tokenized = self.src_tokenize(self.src_vals[idx])
        tgt_tokenized = self.tgt_tokenize(self.tgt_vals[idx])
        src_ids = self.src_vocab.encode(src_tokenized)
        tgt_ids = self.tgt_vocab.encode(tgt_tokenized)

        # Room left after reserving 3 positions for the BOS/EOS/PAD markers
        enc_excess_tokens = self.config.src_max_len - len(src_ids) - 3
        dec_excess_tokens = self.config.tgt_max_len - len(tgt_ids) - 3

        if self.config.truncate:
            if enc_excess_tokens < 0:
                src_ids = src_ids[:self.config.src_max_len - 3]
            if dec_excess_tokens < 0:
                tgt_ids = tgt_ids[:self.config.tgt_max_len - 3]
        else:
            if enc_excess_tokens < 0 or dec_excess_tokens < 0:
                raise ValueError(
                    f"Sentence is too long \n"
                    f"enc_excess_tokens: {enc_excess_tokens}, dec_excess_tokens: {dec_excess_tokens}"
                )

        if self.config.is_termwise:
            # Term-wise sequences carry no BOS/EOS, only a trailing PAD marker
            src_tensor = torch.cat(
                [
                    torch.tensor(src_ids, dtype=torch.int64),
                    self.pad_token,
                ],
                dim=0,
            )
            tgt_tensor = torch.cat(
                [
                    torch.tensor(tgt_ids, dtype=torch.int64),
                    self.pad_token,
                ],
                dim=0,
            )
        else:
            # Full sequences are wrapped as <BOS> ... <EOS> <PAD>
            src_tensor = torch.cat(
                [
                    self.bos_token,
                    torch.tensor(src_ids, dtype=torch.int64),
                    self.eos_token,
                    self.pad_token,
                ],
                dim=0,
            )
            tgt_tensor = torch.cat(
                [
                    self.bos_token,
                    torch.tensor(tgt_ids, dtype=torch.int64),
                    self.eos_token,
                    self.pad_token,
                ],
                dim=0,
            )

        return src_tensor, tgt_tensor

    @staticmethod
    def get_data(df_train, df_test, df_valid, config, tokenizer, src_vocab, tgt_vocab):
        """
        Create datasets (train, test, and valid).

        Returns:
            dict: Dictionary containing train, test, and valid datasets.
        """
        train = Data(df_train, tokenizer, config, src_vocab, tgt_vocab)
        test = Data(df_test, tokenizer, config, src_vocab, tgt_vocab) if df_test is not None else None
        valid = Data(df_valid, tokenizer, config, src_vocab, tgt_vocab)

        return {'train': train, 'test': test, 'valid': valid}
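A minimal sketch of how `Data.get_data` might be wired into PyTorch DataLoaders. Only the `Data` class appears in this diff, so the padding collate function and the `build_loaders` helper below are assumptions about the surrounding pipeline, not code from the repository:

```python
# Hypothetical wiring sketch; collate_fn and build_loaders are not part of the diff.
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from .constants import PAD_IDX
from .data import Data


def collate_fn(batch):
    """Pad variable-length (src, tgt) pairs in a batch to a common length."""
    src_batch, tgt_batch = zip(*batch)
    src = pad_sequence(src_batch, batch_first=True, padding_value=PAD_IDX)
    tgt = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_IDX)
    return src, tgt


def build_loaders(df_train, df_valid, config, tokenizer, src_vocab, tgt_vocab):
    """Build train/valid DataLoaders from DataFrames with 'amp' and 'sqamp' columns."""
    datasets = Data.get_data(df_train, None, df_valid, config, tokenizer, src_vocab, tgt_vocab)
    train_loader = DataLoader(
        datasets['train'],
        batch_size=config.training_batch_size,
        shuffle=config.train_shuffle,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        collate_fn=collate_fn,
    )
    valid_loader = DataLoader(
        datasets['valid'],
        batch_size=config.valid_batch_size,
        shuffle=config.valid_shuffle,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        collate_fn=collate_fn,
    )
    return train_loader, valid_loader
```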