cost of torch multiply

The following code benchmarks how fast multiplication runs in PyTorch for tensors of different dtypes ([5, 5] * [5, 10]):
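The script itself is not reproduced below, only its profiler output; a minimal sketch of what pytorch_type.py presumably looks like (run as `kernprof -l -v pytorch_type.py`; the try/except fallback is an addition so the sketch also runs without kernprof injecting `@profile`):

```python
import torch

# Fallback: kernprof injects `profile` into builtins; define a no-op
# decorator so the script also runs as plain `python pytorch_type.py`.
try:
    profile
except NameError:
    def profile(fn):
        return fn

@profile
def fp_32_matmul(x, y):
    return torch.matmul(x, y)

@profile
def int_32_matmul(x, y):
    x = x.to(torch.int32)
    y = y.to(torch.int32)
    return torch.matmul(x, y)

@profile
def int_16_matmul(x, y):
    x = x.to(torch.int16)
    y = y.to(torch.int16)
    return torch.matmul(x, y)

@profile
def int_8_matmul(x, y):
    x = x.to(torch.int8)
    y = y.to(torch.int8)
    return torch.matmul(x, y)

@profile
def uint_8_matmul(x, y):
    x = x.to(torch.uint8)
    y = y.to(torch.uint8)
    return torch.matmul(x, y)

tensor_fp32 = torch.randn(5, 5)
tensor_fp32_2 = torch.randn(5, 10)

fp_32_matmul(tensor_fp32, tensor_fp32_2)
int_32_matmul(tensor_fp32, tensor_fp32_2)
int_16_matmul(tensor_fp32, tensor_fp32_2)
int_8_matmul(tensor_fp32, tensor_fp32_2)
uint_8_matmul(tensor_fp32, tensor_fp32_2)
```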

Wrote profile results to pytorch_type.py.lprof
Timer unit: 1e-06 s

Total time: 0.000321581 s
File: pytorch_type.py
Function: fp_32_matmul at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           @profile
     4                                           def fp_32_matmul(x, y):
     5         1        321.6    321.6    100.0      return torch.matmul(x, y)

Total time: 0.000299756 s
File: pytorch_type.py
Function: int_32_matmul at line 7

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     7                                           @profile
     8                                           def int_32_matmul(x, y):
     9         1        251.7    251.7     84.0      x = x.to(torch.int32)
    10         1         11.0     11.0      3.7      y = y.to(torch.int32)
    11         1         37.0     37.0     12.3      return torch.matmul(x, y)

Total time: 3.0522e-05 s
File: pytorch_type.py
Function: int_16_matmul at line 13

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    13                                           @profile
    14                                           def int_16_matmul(x, y):
    15         1         11.2     11.2     36.6      x = x.to(torch.int16)
    16         1          7.7      7.7     25.2      y = y.to(torch.int16)
    17         1         11.7     11.7     38.3      return torch.matmul(x, y)

Total time: 3.5534e-05 s
File: pytorch_type.py
Function: int_8_matmul at line 19

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    19                                           @profile
    20                                           def int_8_matmul(x, y):
    21         1         16.5     16.5     46.3      x = x.to(torch.int8)
    22         1          8.0      8.0     22.4      y = y.to(torch.int8)
    23         1         11.1     11.1     31.2      return torch.matmul(x, y)

Total time: 3.3385e-05 s
File: pytorch_type.py
Function: uint_8_matmul at line 25

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    25                                           @profile
    26                                           def uint_8_matmul(x, y):
    27         1          9.9      9.9     29.7      x = x.to(torch.uint8)
    28         1         12.3     12.3     36.9      y = y.to(torch.uint8)
    29         1         11.2     11.2     33.4      return torch.matmul(x, y)

At this size, integer multiplication looks much faster than floating-point matrix multiplication. But if we enlarge the matrices to [3, 500, 768] * [768, 768]:
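Note that this larger case is a batched matmul: the right-hand [768, 768] matrix is broadcast across the batch dimension of the [3, 500, 768] tensor. A quick shape check:

```python
import torch

x = torch.randn(3, 500, 768)   # batch of 3 matrices, each 500 x 768
y = torch.randn(768, 768)      # broadcast over the batch dimension

out = x @ y
print(out.shape)  # torch.Size([3, 500, 768])
```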

Wrote profile results to pytorch_type.py.lprof
Timer unit: 1e-06 s

Total time: 0.0055994 s
File: pytorch_type.py
Function: fp_32_matmul at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           @profile
     4                                           def fp_32_matmul(x, y):
     5         1       5599.4   5599.4    100.0      return x @ y

Total time: 0.172675 s
File: pytorch_type.py
Function: int_32_matmul at line 7

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     7                                           @profile
     8                                           def int_32_matmul(x, y):
     9         1        937.7    937.7      0.5      x = x.to(torch.int32)
    10         1        381.6    381.6      0.2      y = y.to(torch.int32)
    11         1     171355.3 171355.3     99.2      return x @ y

Total time: 0.0797359 s
File: pytorch_type.py
Function: int_16_matmul at line 13

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    13                                           @profile
    14                                           def int_16_matmul(x, y):
    15         1        314.3    314.3      0.4      x = x.to(torch.int16)
    16         1        106.3    106.3      0.1      y = y.to(torch.int16)
    17         1      79315.3  79315.3     99.5      return x @ y

Total time: 0.048643 s
File: pytorch_type.py
Function: int_8_matmul at line 19

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    19                                           @profile
    20                                           def int_8_matmul(x, y):
    21         1        368.0    368.0      0.8      x = x.to(torch.int8)
    22         1        123.3    123.3      0.3      y = y.to(torch.int8)
    23         1      48151.7  48151.7     99.0      return x @ y

Total time: 0.0488203 s
File: pytorch_type.py
Function: uint_8_matmul at line 25

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    25                                           @profile
    26                                           def uint_8_matmul(x, y):
    27         1        293.1    293.1      0.6      x = x.to(torch.uint8)
    28         1         96.8     96.8      0.2      y = y.to(torch.uint8)
    29         1      48430.4  48430.4     99.2      return x @ y

In this case floating-point multiplication is in turn much faster than integer multiplication. A reasonable inference is that PyTorch's floating-point matmul is backed by an optimized low-level implementation (such as a BLAS library), while integer matmul falls back to a generic, unoptimized kernel.
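The gap can also be reproduced with plain wall-clock timing instead of line_profiler (a minimal sketch; the absolute numbers depend on hardware and on which BLAS build PyTorch ships with):

```python
import time
import torch

x = torch.randn(3, 500, 768)
y = torch.randn(768, 768)
xi, yi = x.to(torch.int32), y.to(torch.int32)

def bench(fn, n=10):
    fn()  # warm-up run, excluded from timing
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - t0) / n

t_fp = bench(lambda: x @ y)
t_int = bench(lambda: xi @ yi)
print(f"fp32: {t_fp * 1e3:.2f} ms, int32: {t_int * 1e3:.2f} ms")
```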

But if we move the computation from CPU to GPU, the following happens:

Traceback (most recent call last):
  File "/home/yujin/anaconda3/envs/quant/bin/kernprof", line 8, in <module>
    sys.exit(main())
  File "/home/yujin/anaconda3/envs/quant/lib/python3.10/site-packages/kernprof.py", line 264, in main
    execfile(script_file, ns, ns)
  File "/home/yujin/anaconda3/envs/quant/lib/python3.10/site-packages/kernprof.py", line 32, in execfile
    exec(compile(f.read(), filename, 'exec'), globals, locals)
  File "pytorch_type.py", line 37, in <module>
    int_32_matmul(tensor_fp32, tensor_fp32_2)
  File "/home/yujin/anaconda3/envs/quant/lib/python3.10/site-packages/line_profiler/line_profiler.py", line 130, in wrapper
    result = func(*args, **kwds)
  File "pytorch_type.py", line 11, in int_32_matmul
    return torch.matmul(x, y)
RuntimeError: "addmm_cuda" not implemented for 'Int'

This shows that the current CUDA backend does not directly support (let alone accelerate) integer matrix multiplication.
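A minimal sketch of the failure mode, guarded so it also runs on CPU-only machines: the same int32 matmul that succeeds on CPU raises a RuntimeError on CUDA, because no integer addmm kernel is implemented there.

```python
import torch

x = torch.randn(3, 500, 768)
y = torch.randn(768, 768)

# Integer matmul works on CPU (via a slow generic kernel)...
cpu_out = x.to(torch.int32) @ y.to(torch.int32)

# ...but the CUDA backend has no integer addmm kernel, so the same
# call on GPU raises: RuntimeError: "addmm_cuda" not implemented for 'Int'
if torch.cuda.is_available():
    try:
        x.to(torch.int32).cuda() @ y.to(torch.int32).cuda()
    except RuntimeError as e:
        print(e)
```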
