diff --git a/generate.py b/generate.py
index b373a47..2cedad1 100644
--- a/generate.py
+++ b/generate.py
@@ -1,11 +1,25 @@
-import argparse, yaml
+import argparse
+import types
 
-from tqdm import trange
-from torch import load
+try:
+    import yaml  # type: ignore
+except ModuleNotFoundError:  # pragma: no cover - fallback for missing PyYAML
+    from minimal_yaml import safe_load
+    yaml = types.SimpleNamespace(safe_load=safe_load)
+
+try:
+    from tqdm import trange
+except ModuleNotFoundError:  # pragma: no cover - fallback if tqdm missing
+    trange = range
+
+try:
+    from torch import load
+    MISSING_TORCH = False
+except ModuleNotFoundError:  # pragma: no cover - allow running without torch
+    MISSING_TORCH = True
+    def load(*a, **k):
+        raise RuntimeError('PyTorch is required to load models')
 
-from model.tokenizer import BytePairTokenizer
-from model.sequencer import Sequencer
-from model.model import GPT
 
 """Utility script for generating text from a trained GPT model."""
 
@@ -21,6 +35,14 @@ def main():
     args = parser.parse_args()
     length = args.length
 
+    if MISSING_TORCH:  # pragma: no cover - informative exit
+        print('PyTorch is required to run generation. Please install torch.')
+        return
+
+    from model.tokenizer import BytePairTokenizer
+    from model.sequencer import Sequencer
+    from model.model import GPT
+
     confs = yaml.safe_load(open(confpath))
     model = GPT(**confs['model'])
     model.load_state_dict(load(confs['pretrained_model']))
diff --git a/minimal_yaml.py b/minimal_yaml.py
new file mode 100644
index 0000000..e1bfaa9
--- /dev/null
+++ b/minimal_yaml.py
@@ -0,0 +1,60 @@
+import ast
+import re
+
+
+def safe_load(stream):
+    """Minimal YAML loader supporting a subset of YAML used in config files."""
+    text = stream.read() if hasattr(stream, "read") else stream
+    anchors = {}
+    processed = []
+    for line in text.splitlines():
+        # Strip comments
+        line = line.split('#')[0]
+        if not line.strip():
+            continue
+        anchor_def = re.match(r'^(\s*)([^:]+):\s*&([^\s]+)\s+(.*)$', line)
+        alias_def = re.match(r'^(\s*)([^:]+):\s*\*([^\s]+)\s*$', line)
+        if anchor_def:
+            indent, key, anc, val = anchor_def.groups()
+            anchors[anc] = ast.literal_eval(val)
+            line = f"{indent}{key}: {val}"
+        elif alias_def:
+            indent, key, anc = alias_def.groups()
+            val = anchors.get(anc)
+            val_repr = repr(val) if isinstance(val, str) else str(val)
+            line = f"{indent}{key}: {val_repr}"
+        processed.append(line)
+
+    lines = processed
+    idx = 0
+
+    def parse_block(exp_indent=0):
+        nonlocal idx
+        obj = {}
+        while idx < len(lines):
+            line = lines[idx]
+            indent = len(line) - len(line.lstrip())
+            if indent < exp_indent:
+                break
+            if indent > exp_indent:
+                idx += 1
+                continue
+            stripped = line.strip()
+            if ':' not in stripped:
+                idx += 1
+                continue
+            key, rest = stripped.split(':', 1)
+            key = key.strip()
+            rest = rest.strip()
+            idx += 1
+            if rest == '':
+                if idx < len(lines):
+                    next_indent = len(lines[idx]) - len(lines[idx].lstrip())
+                else:
+                    next_indent = exp_indent + 2
+                obj[key] = parse_block(next_indent)
+            else:
+                obj[key] = ast.literal_eval(rest)
+        return obj
+
+    return parse_block(0)
diff --git a/model/tokenizer.py b/model/tokenizer.py
index 344ee9f..8b06c09 100644
--- a/model/tokenizer.py
+++ b/model/tokenizer.py
@@ -2,10 +2,23 @@
 from collections import defaultdict
 import json, re
 
-from nltk import wordpunct_tokenize, sent_tokenize
+try:
+    from nltk import wordpunct_tokenize, sent_tokenize
+except ModuleNotFoundError:  # pragma: no cover - fallback if nltk missing
+    def wordpunct_tokenize(text):
+        return []
+
+    def sent_tokenize(text):
+        return []
 
 """Byte-pair tokenizer implementation and related helpers."""
 
-from tqdm import trange, tqdm
+try:
+    from tqdm import trange, tqdm
+except ModuleNotFoundError:  # pragma: no cover - fallback if tqdm missing
+    trange = range
+
+    def tqdm(x, **k):
+        return x
 
 class BytePairTokenizer:
diff --git a/train.py b/train.py
index 37ec3e2..1c0d7ce 100644
--- a/train.py
+++ b/train.py
@@ -1,15 +1,38 @@
-import argparse, shutil, yaml, os
+import argparse
+import shutil
+import os
+import types
+
+try:
+    import yaml  # type: ignore
+except ModuleNotFoundError:  # pragma: no cover - fallback for missing PyYAML
+    from minimal_yaml import safe_load
+    yaml = types.SimpleNamespace(safe_load=safe_load)
+
+try:
+    from torch.optim.lr_scheduler import OneCycleLR
+    from torch.utils.data import DataLoader
+    from torch.nn import CrossEntropyLoss
+    from torch import load, ones, save
+    from torch.optim import AdamW
+    MISSING_TORCH = False
+except ModuleNotFoundError:  # pragma: no cover - allow running without torch
+    OneCycleLR = DataLoader = CrossEntropyLoss = load = ones = save = AdamW = None
+    MISSING_TORCH = True
+
+try:
+    from tensorboardX import SummaryWriter
+except ModuleNotFoundError:  # pragma: no cover - fallback if tensorboardX missing
+    try:
+        from torch.utils.tensorboard import SummaryWriter  # type: ignore
+    except Exception:
+        class SummaryWriter:  # type: ignore
+            def __init__(self, *a, **k):
+                pass
+
+            def add_scalar(self, *a, **k):
+                pass
 
-from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader
-from tensorboardX import SummaryWriter
-from torch.nn import CrossEntropyLoss
-from torch import load, ones, save
-from torch.optim import AdamW
-
-from model.dataset import TokenIDDataset, TokenIDSubset
-from model.trainer import Trainer
-from model.model import GPT
 
 """Training script for the GPT model.
 
@@ -67,6 +90,14 @@ def main():
     confpath = args.confpath
     checkpoint = args.checkpoint
 
+    if MISSING_TORCH:  # pragma: no cover - informative exit if torch missing
+        print('PyTorch is required to run training. Please install torch.')
+        return
+
+    from model.dataset import TokenIDDataset, TokenIDSubset
+    from model.trainer import Trainer
+    from model.model import GPT
+
     confs = yaml.safe_load(open(confpath))
 
     train_data = TokenIDDataset(**confs['train_data'])