Spaces:
Sleeping
Sleeping
Upload __init__.py
Browse files- trellis/modules/sparse/__init__.py +102 -102
trellis/modules/sparse/__init__.py
CHANGED
|
@@ -1,102 +1,102 @@
|
|
| 1 |
-
from typing import *
|
| 2 |
-
|
| 3 |
-
BACKEND = 'spconv'
|
| 4 |
-
DEBUG = False
|
| 5 |
-
ATTN = 'flash_attn'
|
| 6 |
-
|
| 7 |
-
def __from_env():
|
| 8 |
-
import os
|
| 9 |
-
|
| 10 |
-
global BACKEND
|
| 11 |
-
global DEBUG
|
| 12 |
-
global ATTN
|
| 13 |
-
|
| 14 |
-
env_sparse_backend = os.environ.get('SPARSE_BACKEND')
|
| 15 |
-
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
|
| 16 |
-
env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
|
| 17 |
-
if env_sparse_attn is None:
|
| 18 |
-
env_sparse_attn = os.environ.get('ATTN_BACKEND')
|
| 19 |
-
|
| 20 |
-
if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
|
| 21 |
-
BACKEND = env_sparse_backend
|
| 22 |
-
if env_sparse_debug is not None:
|
| 23 |
-
DEBUG = env_sparse_debug == '1'
|
| 24 |
-
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
|
| 25 |
-
ATTN = env_sparse_attn
|
| 26 |
-
|
| 27 |
-
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
__from_env()
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def set_backend(backend: Literal['spconv', 'torchsparse']):
|
| 34 |
-
global BACKEND
|
| 35 |
-
BACKEND = backend
|
| 36 |
-
|
| 37 |
-
def set_debug(debug: bool):
|
| 38 |
-
global DEBUG
|
| 39 |
-
DEBUG = debug
|
| 40 |
-
|
| 41 |
-
def set_attn(attn: Literal['xformers', 'flash_attn']):
|
| 42 |
-
global ATTN
|
| 43 |
-
ATTN = attn
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
import importlib
|
| 47 |
-
|
| 48 |
-
__attributes = {
|
| 49 |
-
'SparseTensor': 'basic',
|
| 50 |
-
'sparse_batch_broadcast': 'basic',
|
| 51 |
-
'sparse_batch_op': 'basic',
|
| 52 |
-
'sparse_cat': 'basic',
|
| 53 |
-
'sparse_unbind': 'basic',
|
| 54 |
-
'SparseGroupNorm': 'norm',
|
| 55 |
-
'SparseLayerNorm': 'norm',
|
| 56 |
-
'SparseGroupNorm32': 'norm',
|
| 57 |
-
'SparseLayerNorm32': 'norm',
|
| 58 |
-
'SparseReLU': 'nonlinearity',
|
| 59 |
-
'SparseSiLU': 'nonlinearity',
|
| 60 |
-
'SparseGELU': 'nonlinearity',
|
| 61 |
-
'SparseActivation': 'nonlinearity',
|
| 62 |
-
'SparseLinear': 'linear',
|
| 63 |
-
'sparse_scaled_dot_product_attention': 'attention',
|
| 64 |
-
'SerializeMode': 'attention',
|
| 65 |
-
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
|
| 66 |
-
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
|
| 67 |
-
'SparseMultiHeadAttention': 'attention',
|
| 68 |
-
'SparseConv3d': 'conv',
|
| 69 |
-
'SparseInverseConv3d': 'conv',
|
| 70 |
-
'SparseDownsample': 'spatial',
|
| 71 |
-
'SparseUpsample': 'spatial',
|
| 72 |
-
'SparseSubdivide' : 'spatial'
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
__submodules = ['transformer']
|
| 76 |
-
|
| 77 |
-
__all__ = list(__attributes.keys()) + __submodules
|
| 78 |
-
|
| 79 |
-
def __getattr__(name):
|
| 80 |
-
if name not in globals():
|
| 81 |
-
if name in __attributes:
|
| 82 |
-
module_name = __attributes[name]
|
| 83 |
-
module = importlib.import_module(f".{module_name}", __name__)
|
| 84 |
-
globals()[name] = getattr(module, name)
|
| 85 |
-
elif name in __submodules:
|
| 86 |
-
module = importlib.import_module(f".{name}", __name__)
|
| 87 |
-
globals()[name] = module
|
| 88 |
-
else:
|
| 89 |
-
raise AttributeError(f"module {__name__} has no attribute {name}")
|
| 90 |
-
return globals()[name]
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
# For Pylance
|
| 94 |
-
if __name__ == '__main__':
|
| 95 |
-
from .basic import *
|
| 96 |
-
from .norm import *
|
| 97 |
-
from .nonlinearity import *
|
| 98 |
-
from .linear import *
|
| 99 |
-
from .attention import *
|
| 100 |
-
from .conv import *
|
| 101 |
-
from .spatial import *
|
| 102 |
-
import transformer
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
BACKEND = 'spconv'
|
| 4 |
+
DEBUG = False
|
| 5 |
+
ATTN = 'flash_attn'
|
| 6 |
+
|
| 7 |
+
def __from_env():
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
global BACKEND
|
| 11 |
+
global DEBUG
|
| 12 |
+
global ATTN
|
| 13 |
+
|
| 14 |
+
env_sparse_backend = os.environ.get('SPARSE_BACKEND')
|
| 15 |
+
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
|
| 16 |
+
env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
|
| 17 |
+
if env_sparse_attn is None:
|
| 18 |
+
env_sparse_attn = os.environ.get('ATTN_BACKEND')
|
| 19 |
+
|
| 20 |
+
if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
|
| 21 |
+
BACKEND = env_sparse_backend
|
| 22 |
+
if env_sparse_debug is not None:
|
| 23 |
+
DEBUG = env_sparse_debug == '1'
|
| 24 |
+
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sdpa', 'naive']:
|
| 25 |
+
ATTN = env_sparse_attn
|
| 26 |
+
|
| 27 |
+
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__from_env()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def set_backend(backend: Literal['spconv', 'torchsparse']):
|
| 34 |
+
global BACKEND
|
| 35 |
+
BACKEND = backend
|
| 36 |
+
|
| 37 |
+
def set_debug(debug: bool):
|
| 38 |
+
global DEBUG
|
| 39 |
+
DEBUG = debug
|
| 40 |
+
|
| 41 |
+
def set_attn(attn: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
|
| 42 |
+
global ATTN
|
| 43 |
+
ATTN = attn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
import importlib
|
| 47 |
+
|
| 48 |
+
__attributes = {
|
| 49 |
+
'SparseTensor': 'basic',
|
| 50 |
+
'sparse_batch_broadcast': 'basic',
|
| 51 |
+
'sparse_batch_op': 'basic',
|
| 52 |
+
'sparse_cat': 'basic',
|
| 53 |
+
'sparse_unbind': 'basic',
|
| 54 |
+
'SparseGroupNorm': 'norm',
|
| 55 |
+
'SparseLayerNorm': 'norm',
|
| 56 |
+
'SparseGroupNorm32': 'norm',
|
| 57 |
+
'SparseLayerNorm32': 'norm',
|
| 58 |
+
'SparseReLU': 'nonlinearity',
|
| 59 |
+
'SparseSiLU': 'nonlinearity',
|
| 60 |
+
'SparseGELU': 'nonlinearity',
|
| 61 |
+
'SparseActivation': 'nonlinearity',
|
| 62 |
+
'SparseLinear': 'linear',
|
| 63 |
+
'sparse_scaled_dot_product_attention': 'attention',
|
| 64 |
+
'SerializeMode': 'attention',
|
| 65 |
+
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
|
| 66 |
+
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
|
| 67 |
+
'SparseMultiHeadAttention': 'attention',
|
| 68 |
+
'SparseConv3d': 'conv',
|
| 69 |
+
'SparseInverseConv3d': 'conv',
|
| 70 |
+
'SparseDownsample': 'spatial',
|
| 71 |
+
'SparseUpsample': 'spatial',
|
| 72 |
+
'SparseSubdivide' : 'spatial'
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
__submodules = ['transformer']
|
| 76 |
+
|
| 77 |
+
__all__ = list(__attributes.keys()) + __submodules
|
| 78 |
+
|
| 79 |
+
def __getattr__(name):
|
| 80 |
+
if name not in globals():
|
| 81 |
+
if name in __attributes:
|
| 82 |
+
module_name = __attributes[name]
|
| 83 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
| 84 |
+
globals()[name] = getattr(module, name)
|
| 85 |
+
elif name in __submodules:
|
| 86 |
+
module = importlib.import_module(f".{name}", __name__)
|
| 87 |
+
globals()[name] = module
|
| 88 |
+
else:
|
| 89 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
| 90 |
+
return globals()[name]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# For Pylance
|
| 94 |
+
if __name__ == '__main__':
|
| 95 |
+
from .basic import *
|
| 96 |
+
from .norm import *
|
| 97 |
+
from .nonlinearity import *
|
| 98 |
+
from .linear import *
|
| 99 |
+
from .attention import *
|
| 100 |
+
from .conv import *
|
| 101 |
+
from .spatial import *
|
| 102 |
+
import transformer
|