exportable Hardswish() implementation
This commit is contained in:
parent
fd71fe8451
commit
71209a6099
|
|
@ -10,6 +10,13 @@ class Swish(nn.Module): #
|
||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Hardswish(nn.Module): # alternative to nn.Hardswish() for export
|
||||||
|
@staticmethod
|
||||||
|
def forward(x):
|
||||||
|
# return x * F.hardsigmoid(x)
|
||||||
|
return x * F.hardtanh(x + 3, 0., 6.) / 6.
|
||||||
|
|
||||||
|
|
||||||
class MemoryEfficientSwish(nn.Module):
|
class MemoryEfficientSwish(nn.Module):
|
||||||
class F(torch.autograd.Function):
|
class F(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue