| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from ..octree import DfsOctree as Octree |
|
|
|
|
class Strivec(Octree):
    """Octree whose primitive is ``"trivec"``, configured by ``rank`` and ``dim``.

    The octree depth is derived from ``resolution``: a resolution of
    ``2**depth`` produces an octree of that ``depth``.
    """

    def __init__(
        self,
        resolution: int,
        aabb: list,
        sh_degree: int = 0,
        rank: int = 8,
        dim: int = 8,
        device: str = "cuda",
    ):
        """Build the octree for a power-of-two ``resolution``.

        Args:
            resolution: Positive power of two. Integral floats such as
                ``256.0`` are accepted. The value is stored unchanged on
                ``self.resolution``.
            aabb: Forwarded to the base octree. Presumably the scene's
                axis-aligned bounding box — confirm against ``DfsOctree``.
            sh_degree: Forwarded to the base octree. Presumably a
                spherical-harmonics degree — confirm against ``DfsOctree``.
            rank: Passed as ``primitive_config["rank"]``.
            dim: Passed as ``primitive_config["dim"]``.
            device: Forwarded to the base octree.

        Raises:
            ValueError: If ``resolution`` is not a positive power of two.
        """
        # Raise explicitly instead of using `assert`: asserts are stripped
        # under `python -O`, which would let a non-power-of-two resolution
        # through with a silently rounded depth. Checking `<= 0` first also
        # avoids numpy log2 warnings on non-positive inputs.
        if resolution <= 0 or int(resolution) != resolution:
            raise ValueError("Resolution must be a power of 2")
        res_int = int(resolution)
        # Exact integer power-of-two test (a power of two has a single set
        # bit), unlike the float-based `log2(x) % 1 == 0`.
        if res_int & (res_int - 1):
            raise ValueError("Resolution must be a power of 2")

        self.resolution = resolution
        # For a power of two, bit_length() - 1 == log2(resolution) exactly.
        depth = res_int.bit_length() - 1
        super().__init__(
            depth=depth,
            aabb=aabb,
            sh_degree=sh_degree,
            primitive="trivec",
            primitive_config={"rank": rank, "dim": dim},
            device=device,
        )
|
|