EDEN / eden /__init__.py
Rybib's picture
Upload EDEN model and code
2f65125 verified
Raw
History Blame
1.34 kB
"""EDEN: Encoder Decoder Enhancement Network.
A from-scratch PyTorch encoder-decoder Transformer for text enhancement.
"""
import os
# Runtime knobs read by PyTorch (MPS allocator / CPU fallback), HF tokenizers
# and the OpenMP / MKL thread pools. They are applied with setdefault, so any
# value already exported by the user wins. They must be in place before torch
# is first imported (the submodule imports below come after this), and
# presumably they have no effect if the caller imported torch before this
# package — NOTE(review): confirm for each variable.
for _name, _value in (
    ("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.88"),
    ("PYTORCH_MPS_LOW_WATERMARK_RATIO", "0.70"),
    ("PYTORCH_ENABLE_MPS_FALLBACK", "1"),
    ("TOKENIZERS_PARALLELISM", "false"),
    ("OMP_NUM_THREADS", "4"),
    ("MKL_NUM_THREADS", "4"),
):
    os.environ.setdefault(_name, _value)
# Keep the package namespace identical to a plain sequence of setdefault calls.
del _name, _value
import warnings
# Silence the UserWarning whose text starts with "enable_nested_tensor is True"
# and mentions "norm_first was True". This pattern targets PyTorch's
# nn.TransformerEncoder, which, presumably, is built with norm_first=True
# layers in .model (confirm there). The message is a regex matched
# case-insensitively against the start of the warning text. The filter is
# inserted at the front of the process-wide filter list.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    message="enable_nested_tensor is True.*norm_first was True",
)
from .config import RECIPES, TrainConfig, apply_recipe, model_param_count
from .model import EdenTransformer, PositionalEncoding
from .data import load_prepared_pairs, load_tokenizer, read_pairs_file, train_tokenizer
from .engine import enhance_text, load_model_for_inference, main, train_loop
__version__ = "1.0.0"
# Names exported by `from eden import *`, grouped by the submodule each one is
# re-exported from. The list order is kept stable on purpose.
__all__ = [
    # .config
    "TrainConfig", "RECIPES", "apply_recipe", "model_param_count",
    # .model
    "EdenTransformer", "PositionalEncoding",
    # .data
    "load_tokenizer", "train_tokenizer", "read_pairs_file", "load_prepared_pairs",
    # .engine
    "enhance_text", "load_model_for_inference", "train_loop", "main",
]