Files
carbon e25f20f7a3 add cviruntime
commit 3f4938648950a7f3bf9a19c320ca9fae7c52de20
Author: sophgo-forum-service <forum_service@sophgo.com>
Date:   Mon May 13 13:44:23 2024 +0800

    [feat] cviruntime opensource for cv18xx soc.

    - a4b6a3, add cumsum and gatherelements_pt.
2024-05-31 11:51:34 +08:00

25 lines
721 B
Python

import torch
import torch.nn as nn
import numpy as np
class Light(nn.Module):
    """Map each pixel to ``255 - (local_max - pixel)``.

    A ``kernel_size`` x ``kernel_size`` max-pool (stride 1, symmetric padding)
    computes a local maximum ``bkg`` around every pixel. The output is
    ``255 - (bkg - x)`` where ``bkg > x`` and ``255`` elsewhere.

    NOTE(review): presumably ``x`` is an (N, C, H, W) float tensor holding
    8-bit pixel values in [0, 255] (see the ``__main__`` driver) — confirm
    against other callers.
    """

    def __init__(self, kernel_size=15):
        """Build the module.

        Args:
            kernel_size: Side length of the square max-pool window. Must be a
                positive odd integer so that the symmetric padding of
                ``(kernel_size - 1) // 2`` keeps the pooled output the same
                spatial size as the input.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        # With an even kernel, padding (k - 1) // 2 shrinks H and W by one,
        # so ``bkg`` would no longer line up element-wise with ``x`` in
        # forward(). Fail early with a clear message instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.pool = nn.MaxPool2d(
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )

    def forward(self, x):
        """Return ``255 - (bkg - x)`` where ``bkg > x``, else ``255``.

        Args:
            x: Input tensor accepted by ``nn.MaxPool2d``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        bkg = self.pool(x)
        # The odd, stride-1, symmetrically padded window is centred on each
        # pixel and contains it, so bkg >= x. Where bkg == x both branches
        # give 255, so the else-branch only matters where the comparison is
        # not true for other reasons.
        # torch.tensor(255.) is a 0-dim constant with the default dtype; it
        # is kept as-is so the recorded op graph does not change.
        x = torch.where(bkg > x, 255 - (bkg - x), torch.tensor(255.))
        return x
if __name__ == '__main__':
    # Generate a random 8-bit test image, run it through Light, and save the
    # input and reference output as .npz files.
    net = Light()
    shape = (1, 1, 40, 70)
    # torch.randint's ``high`` is exclusive: use 256 so the sampled pixels
    # cover the full uint8 range 0..255 that the .npz files store.
    x = torch.randint(256, shape)
    x = x.to(torch.float32)
    # Values are integral and within 0..255, so the uint8 cast is lossless.
    np.savez("light_x.npz", x=x.numpy().astype(np.uint8))
    print("x", x.shape, x)
    out = net(x)
    # out = 255 - (bkg - x) with integral bkg, x in 0..255, so it is also
    # integral and within 0..255; the uint8 cast is lossless.
    np.savez("light_out.npz", out=out.numpy().astype(np.uint8))
    print("out", out.shape, out)