yonnel commited on
Commit
7cc59a0
·
verified ·
1 Parent(s): 5717970

Upload __init__.py

Browse files
Files changed (1) hide show
  1. trellis/modules/sparse/__init__.py +102 -102
trellis/modules/sparse/__init__.py CHANGED
@@ -1,102 +1,102 @@
1
- from typing import *
2
-
3
- BACKEND = 'spconv'
4
- DEBUG = False
5
- ATTN = 'flash_attn'
6
-
7
- def __from_env():
8
- import os
9
-
10
- global BACKEND
11
- global DEBUG
12
- global ATTN
13
-
14
- env_sparse_backend = os.environ.get('SPARSE_BACKEND')
15
- env_sparse_debug = os.environ.get('SPARSE_DEBUG')
16
- env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
17
- if env_sparse_attn is None:
18
- env_sparse_attn = os.environ.get('ATTN_BACKEND')
19
-
20
- if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
21
- BACKEND = env_sparse_backend
22
- if env_sparse_debug is not None:
23
- DEBUG = env_sparse_debug == '1'
24
- if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
25
- ATTN = env_sparse_attn
26
-
27
- print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
28
-
29
-
30
- __from_env()
31
-
32
-
33
- def set_backend(backend: Literal['spconv', 'torchsparse']):
34
- global BACKEND
35
- BACKEND = backend
36
-
37
- def set_debug(debug: bool):
38
- global DEBUG
39
- DEBUG = debug
40
-
41
- def set_attn(attn: Literal['xformers', 'flash_attn']):
42
- global ATTN
43
- ATTN = attn
44
-
45
-
46
- import importlib
47
-
48
- __attributes = {
49
- 'SparseTensor': 'basic',
50
- 'sparse_batch_broadcast': 'basic',
51
- 'sparse_batch_op': 'basic',
52
- 'sparse_cat': 'basic',
53
- 'sparse_unbind': 'basic',
54
- 'SparseGroupNorm': 'norm',
55
- 'SparseLayerNorm': 'norm',
56
- 'SparseGroupNorm32': 'norm',
57
- 'SparseLayerNorm32': 'norm',
58
- 'SparseReLU': 'nonlinearity',
59
- 'SparseSiLU': 'nonlinearity',
60
- 'SparseGELU': 'nonlinearity',
61
- 'SparseActivation': 'nonlinearity',
62
- 'SparseLinear': 'linear',
63
- 'sparse_scaled_dot_product_attention': 'attention',
64
- 'SerializeMode': 'attention',
65
- 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
66
- 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
67
- 'SparseMultiHeadAttention': 'attention',
68
- 'SparseConv3d': 'conv',
69
- 'SparseInverseConv3d': 'conv',
70
- 'SparseDownsample': 'spatial',
71
- 'SparseUpsample': 'spatial',
72
- 'SparseSubdivide' : 'spatial'
73
- }
74
-
75
- __submodules = ['transformer']
76
-
77
- __all__ = list(__attributes.keys()) + __submodules
78
-
79
- def __getattr__(name):
80
- if name not in globals():
81
- if name in __attributes:
82
- module_name = __attributes[name]
83
- module = importlib.import_module(f".{module_name}", __name__)
84
- globals()[name] = getattr(module, name)
85
- elif name in __submodules:
86
- module = importlib.import_module(f".{name}", __name__)
87
- globals()[name] = module
88
- else:
89
- raise AttributeError(f"module {__name__} has no attribute {name}")
90
- return globals()[name]
91
-
92
-
93
- # For Pylance
94
- if __name__ == '__main__':
95
- from .basic import *
96
- from .norm import *
97
- from .nonlinearity import *
98
- from .linear import *
99
- from .attention import *
100
- from .conv import *
101
- from .spatial import *
102
- import transformer
 
1
+ from typing import *
2
+
3
+ BACKEND = 'spconv'
4
+ DEBUG = False
5
+ ATTN = 'flash_attn'
6
+
7
+ def __from_env():
8
+ import os
9
+
10
+ global BACKEND
11
+ global DEBUG
12
+ global ATTN
13
+
14
+ env_sparse_backend = os.environ.get('SPARSE_BACKEND')
15
+ env_sparse_debug = os.environ.get('SPARSE_DEBUG')
16
+ env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
17
+ if env_sparse_attn is None:
18
+ env_sparse_attn = os.environ.get('ATTN_BACKEND')
19
+
20
+ if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
21
+ BACKEND = env_sparse_backend
22
+ if env_sparse_debug is not None:
23
+ DEBUG = env_sparse_debug == '1'
24
+ if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sdpa', 'naive']:
25
+ ATTN = env_sparse_attn
26
+
27
+ print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
28
+
29
+
30
+ __from_env()
31
+
32
+
33
+ def set_backend(backend: Literal['spconv', 'torchsparse']):
34
+ global BACKEND
35
+ BACKEND = backend
36
+
37
+ def set_debug(debug: bool):
38
+ global DEBUG
39
+ DEBUG = debug
40
+
41
+ def set_attn(attn: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
42
+ global ATTN
43
+ ATTN = attn
44
+
45
+
46
+ import importlib
47
+
48
+ __attributes = {
49
+ 'SparseTensor': 'basic',
50
+ 'sparse_batch_broadcast': 'basic',
51
+ 'sparse_batch_op': 'basic',
52
+ 'sparse_cat': 'basic',
53
+ 'sparse_unbind': 'basic',
54
+ 'SparseGroupNorm': 'norm',
55
+ 'SparseLayerNorm': 'norm',
56
+ 'SparseGroupNorm32': 'norm',
57
+ 'SparseLayerNorm32': 'norm',
58
+ 'SparseReLU': 'nonlinearity',
59
+ 'SparseSiLU': 'nonlinearity',
60
+ 'SparseGELU': 'nonlinearity',
61
+ 'SparseActivation': 'nonlinearity',
62
+ 'SparseLinear': 'linear',
63
+ 'sparse_scaled_dot_product_attention': 'attention',
64
+ 'SerializeMode': 'attention',
65
+ 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
66
+ 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
67
+ 'SparseMultiHeadAttention': 'attention',
68
+ 'SparseConv3d': 'conv',
69
+ 'SparseInverseConv3d': 'conv',
70
+ 'SparseDownsample': 'spatial',
71
+ 'SparseUpsample': 'spatial',
72
+ 'SparseSubdivide' : 'spatial'
73
+ }
74
+
75
+ __submodules = ['transformer']
76
+
77
+ __all__ = list(__attributes.keys()) + __submodules
78
+
79
+ def __getattr__(name):
80
+ if name not in globals():
81
+ if name in __attributes:
82
+ module_name = __attributes[name]
83
+ module = importlib.import_module(f".{module_name}", __name__)
84
+ globals()[name] = getattr(module, name)
85
+ elif name in __submodules:
86
+ module = importlib.import_module(f".{name}", __name__)
87
+ globals()[name] = module
88
+ else:
89
+ raise AttributeError(f"module {__name__} has no attribute {name}")
90
+ return globals()[name]
91
+
92
+
93
+ # For Pylance
94
+ if __name__ == '__main__':
95
+ from .basic import *
96
+ from .norm import *
97
+ from .nonlinearity import *
98
+ from .linear import *
99
+ from .attention import *
100
+ from .conv import *
101
+ from .spatial import *
102
+ import transformer