Advanced Field Construction
With `FieldManager`, you can tweak how fields are constructed in order to gain performance improvements in kernel calculations.
By supplying a customized `FieldManager` when registering a field, you can construct the field however you want.
WARNING:
- If you don't know why constructing fields differently can improve performance, don't use this feature.
- If you don't know how to construct fields differently, please refer to the Taichi field documentation.
Example
In `auxiliary.py`, `FieldManager` is defined as an abstract class:
class FieldManager(ABC):
    """
    Strategy interface for constructing fields and converting them to/from tensors.

    A FieldManager makes flexible field construction and manipulation possible.
    For example, instead of laying a multidimensional field out in the ordinary
    way, an implementation can place it hierarchically, which may give dramatic
    performance improvements depending on the application.

    Because a hierarchical field may not have the same shape as the input
    tensor, it is the implementer's responsibility to make sure field values
    are transformed into and from tensors correctly.
    """

    @abstractmethod
    def construct_field(
        self,
        fields_builder: ti.FieldsBuilder,
        concrete_tensor_shape: Tuple[int, ...],
        needs_grad: bool,
    ) -> Union[ScalarField, MatrixField]:
        """
        Build, using `fields_builder`, the field that stands for a tensor of
        shape `concrete_tensor_shape`, and return it.

        `needs_grad` tells the implementation whether the field needs a
        gradient counterpart.
        """

    @abstractmethod
    def to_tensor(self, field: Union[ScalarField, MatrixField]) -> torch.Tensor:
        """Return the values held by `field` as a torch tensor."""

    @abstractmethod
    def grad_to_tensor(self, grad_field: Union[ScalarField, MatrixField]) -> torch.Tensor:
        """Return the values held by the gradient field `grad_field` as a torch tensor."""

    @abstractmethod
    def from_tensor(self, field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
        """Load the values of `tensor` into `field`."""

    @abstractmethod
    def grad_from_tensor(self, grad_field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
        """Load the values of `tensor` into the gradient field `grad_field`."""
One example is the `DefaultFieldManager` in `tube.py`, defined as:
class DefaultFieldManager(FieldManager):
    """
    Field manager that uses the ordinary multidimensional array layout.

    Fields are placed densely over every axis of the tensor shape. When
    `complex_dtype` is set, a value is stored as a 2-component vector field and
    exposed on the torch side through `torch.view_as_complex` /
    `torch.view_as_real`.
    """

    def __init__(self, dtype: TiDataType, complex_dtype: bool, device: torch.device):
        """
        Remember the element type, whether values are complex, and the torch
        device that exported tensors are created on.
        """
        self.dtype: TiDataType = dtype
        self.complex_dtype: bool = complex_dtype
        self.device: torch.device = device

    def construct_field(
        self,
        fields_builder: ti.FieldsBuilder,
        concrete_tensor_shape: Tuple[int, ...],
        needs_grad: bool,
    ) -> Union[ScalarField, MatrixField]:
        """
        Create a field for a tensor of `concrete_tensor_shape` and place it
        densely through `fields_builder`, which must not be finalized yet.

        With `needs_grad`, the field's gradient is placed in the same dense
        block as the field itself. Returns the newly created field.
        """
        assert not fields_builder.finalized

        # A complex value occupies a 2-component vector; anything else is a scalar.
        new_field = (
            ti.Vector.field(2, dtype=self.dtype, needs_grad=needs_grad)
            if self.complex_dtype
            else ti.field(self.dtype, needs_grad=needs_grad)
        )

        # One dense block spanning all axes of the tensor shape.
        all_axes = axes(*range(len(concrete_tensor_shape)))
        dense_block = fields_builder.dense(all_axes, concrete_tensor_shape)

        members = (new_field, new_field.grad) if needs_grad else (new_field,)
        dense_block.place(*members)
        return new_field

    def to_tensor(self, field: Union[ScalarField, MatrixField]) -> torch.Tensor:
        """
        Export `field` to a torch tensor on `self.device`, viewed as a complex
        tensor when `complex_dtype` is set.
        """
        exported = field.to_torch(device=self.device)
        return torch.view_as_complex(exported) if self.complex_dtype else exported

    def grad_to_tensor(self, grad_field: Union[ScalarField, MatrixField]) -> torch.Tensor:
        """
        Export the gradient field `grad_field` to a torch tensor on
        `self.device`, viewed as a complex tensor when `complex_dtype` is set.
        """
        exported_grad = grad_field.to_torch(device=self.device)
        if not self.complex_dtype:
            return exported_grad
        return torch.view_as_complex(exported_grad)

    def from_tensor(self, field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
        """
        Load `tensor` into `field`; when `complex_dtype` is set, the tensor is
        first viewed as a real tensor.
        """
        payload = torch.view_as_real(tensor) if self.complex_dtype else tensor
        field.from_torch(payload)

    def grad_from_tensor(self, grad_field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
        """
        Load `tensor` into the gradient field `grad_field`; when
        `complex_dtype` is set, the tensor is first viewed as a real tensor.
        """
        payload = torch.view_as_real(tensor) if self.complex_dtype else tensor
        grad_field.from_torch(payload)