Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions generate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
import argparse, yaml
import argparse
import types

from tqdm import trange
from torch import load
try:
import yaml # type: ignore
except ModuleNotFoundError: # pragma: no cover - fallback for missing PyYAML
from minimal_yaml import safe_load
yaml = types.SimpleNamespace(safe_load=safe_load)

try:
from tqdm import trange
except ModuleNotFoundError: # pragma: no cover - fallback if tqdm missing
trange = range

try:
from torch import load
MISSING_TORCH = False
except ModuleNotFoundError: # pragma: no cover - allow running without torch
MISSING_TORCH = True
    def load(*a, **k):
        # Stub standing in for torch.load when torch is absent: defers the
        # failure from import time to the first actual load attempt, and
        # fails loudly with an actionable message.
        raise RuntimeError('PyTorch is required to load models')

from model.tokenizer import BytePairTokenizer
from model.sequencer import Sequencer
from model.model import GPT

"""Utility script for generating text from a trained GPT model."""

Expand All @@ -21,6 +35,14 @@ def main():
args = parser.parse_args()
length = args.length

if MISSING_TORCH: # pragma: no cover - informative exit
print('PyTorch is required to run generation. Please install torch.')
return

from model.tokenizer import BytePairTokenizer
from model.sequencer import Sequencer
from model.model import GPT

confs = yaml.safe_load(open(confpath))
model = GPT(**confs['model'])
model.load_state_dict(load(confs['pretrained_model']))
Expand Down
60 changes: 60 additions & 0 deletions minimal_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import ast
import re


def safe_load(stream):
    """Minimal YAML loader for the subset of YAML used in config files.

    Supports nested mappings (indentation-based), ``#`` comments, scalar
    values (Python literals, plain unquoted strings, and the YAML keywords
    ``true``/``false``/``yes``/``no``/``null``/``~``), and simple
    ``key: &anchor value`` / ``key: *alias`` pairs.

    Parameters
    ----------
    stream : str or file-like
        YAML text, or an object exposing ``read()``.

    Returns
    -------
    dict
        Nested dictionaries mirroring the mapping structure; ``{}`` for
        empty input.
    """
    text = stream.read() if hasattr(stream, "read") else stream

    def parse_scalar(token):
        # Try a Python literal first, then YAML keywords; anything else is a
        # plain unquoted string (e.g. `act: relu`), which the previous
        # implementation rejected with a ValueError.
        try:
            return ast.literal_eval(token)
        except (ValueError, SyntaxError):
            lowered = token.lower()
            if lowered in ('true', 'yes'):
                return True
            if lowered in ('false', 'no'):
                return False
            if lowered in ('null', '~'):
                return None
            return token

    anchors = {}
    processed = []
    for raw in text.splitlines():
        # NOTE: naive comment stripping — a '#' inside a quoted value would
        # be mangled; acceptable for the simple configs this loader targets.
        raw = raw.split('#')[0]
        if not raw.strip():
            continue
        anchor_def = re.match(r'^(\s*)([^:]+):\s*&(\S+)\s+(.*)$', raw)
        alias_def = re.match(r'^(\s*)([^:]+):\s*\*(\S+)\s*$', raw)
        if anchor_def:
            indent, key, name, val = anchor_def.groups()
            anchors[name] = parse_scalar(val)
            raw = f"{indent}{key}: {val}"
        elif alias_def:
            indent, key, name = alias_def.groups()
            # Re-serialize with repr() so the value round-trips through
            # parse_scalar below (repr of any literal is a valid literal).
            raw = f"{indent}{key}: {anchors.get(name)!r}"
        processed.append(raw)

    lines = processed
    pos = 0

    def parse_block(level=0):
        # Consume lines indented exactly `level` columns into one mapping;
        # stop (without consuming) at the first shallower line so the
        # caller's loop resumes there.
        nonlocal pos
        mapping = {}
        while pos < len(lines):
            line = lines[pos]
            indent = len(line) - len(line.lstrip())
            if indent < level:
                break
            if indent > level or ':' not in line:
                # Stray over-indented or non-mapping line: skip defensively.
                pos += 1
                continue
            key, _, rest = line.strip().partition(':')
            key, rest = key.strip(), rest.strip()
            pos += 1
            if rest:
                mapping[key] = parse_scalar(rest)
            else:
                # Bare `key:` opens a nested block whose level is defined by
                # the next line's indentation (defaults to +2 at EOF).
                if pos < len(lines):
                    child = len(lines[pos]) - len(lines[pos].lstrip())
                else:
                    child = level + 2
                mapping[key] = parse_block(child)
        return mapping

    return parse_block(0)
17 changes: 15 additions & 2 deletions model/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@
from collections import defaultdict
import json, re

from nltk import wordpunct_tokenize, sent_tokenize
try:
from nltk import wordpunct_tokenize, sent_tokenize
except ModuleNotFoundError: # pragma: no cover - fallback if nltk missing
    def wordpunct_tokenize(text):
        # Fallback when nltk is unavailable: yields no tokens.
        # NOTE(review): silently returning [] disables word tokenization
        # entirely — confirm callers tolerate empty output, or consider a
        # regex-based split / explicit error instead.
        return []

    def sent_tokenize(text):
        # Fallback when nltk is unavailable: yields no sentences.
        # NOTE(review): same silent-no-op concern as wordpunct_tokenize.
        return []

"""Byte-pair tokenizer implementation and related helpers."""
from tqdm import trange, tqdm
try:
from tqdm import trange, tqdm
except ModuleNotFoundError: # pragma: no cover - fallback if tqdm missing
trange = range

    def tqdm(x, **k):
        # Identity wrapper used when tqdm is absent: iterate without a
        # progress bar; any tqdm keyword arguments are accepted and ignored.
        return x


class BytePairTokenizer:
Expand Down
53 changes: 42 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
import argparse, shutil, yaml, os
import argparse
import shutil
import os
import types

try:
import yaml # type: ignore
except ModuleNotFoundError: # pragma: no cover - fallback for missing PyYAML
from minimal_yaml import safe_load
yaml = types.SimpleNamespace(safe_load=safe_load)

try:
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch import load, ones, save
from torch.optim import AdamW
MISSING_TORCH = False
except ModuleNotFoundError: # pragma: no cover - allow running without torch
OneCycleLR = DataLoader = CrossEntropyLoss = load = ones = save = AdamW = None
MISSING_TORCH = True

try:
from tensorboardX import SummaryWriter
except ModuleNotFoundError: # pragma: no cover - fallback if tensorboardX missing
try:
from torch.utils.tensorboard import SummaryWriter # type: ignore
except Exception:
        class SummaryWriter: # type: ignore
            """No-op logger used when neither tensorboardX nor torch's
            tensorboard backend can be imported; training proceeds without
            metric logging."""
            def __init__(self, *a, **k):
                # Accept any constructor arguments (e.g. log dir) and discard them.
                pass

            def add_scalar(self, *a, **k):
                # Drop the metric instead of writing a tensorboard event.
                pass

from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch.nn import CrossEntropyLoss
from torch import load, ones, save
from torch.optim import AdamW

from model.dataset import TokenIDDataset, TokenIDSubset
from model.trainer import Trainer
from model.model import GPT

"""Training script for the GPT model.

Expand Down Expand Up @@ -67,6 +90,14 @@ def main():
confpath = args.confpath
checkpoint = args.checkpoint

if MISSING_TORCH: # pragma: no cover - informative exit if torch missing
print('PyTorch is required to run training. Please install torch.')
return

from model.dataset import TokenIDDataset, TokenIDSubset
from model.trainer import Trainer
from model.model import GPT

confs = yaml.safe_load(open(confpath))

train_data = TokenIDDataset(**confs['train_data'])
Expand Down