Skip to content

Commit d1e945d

Browse files
kahmed10 and mvermeulen authored
Add same padding mode for onnx (#456)
* fix pad calc
* add padding calc and test
* formatting
* made asym generic function
* formatting

Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
1 parent 63d8e40 commit d1e945d

File tree

4 files changed

+94
-34
lines changed

4 files changed

+94
-34
lines changed

‎src/onnx/onnx.cpp

+39-34
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <migraphx/instruction.hpp>
1717
#include <migraphx/config.hpp>
1818
#include <migraphx/onnx.hpp>
19+
#include <migraphx/pad_calc.hpp>
1920

2021
namespace migraphx {
2122
inline namespace MIGRAPHX_INLINE_NS {
@@ -302,6 +303,24 @@ struct onnx_parser
302303
return curr_ins;
303304
}
304305

306+
template <class Op>
307+
void check_asym_padding(instruction_ref& ins,
308+
std::vector<int64_t>& padding,
309+
Op& op,
310+
float pad_val = 0)
311+
{
312+
if(padding[0] != padding[2] || padding[1] != padding[3])
313+
{
314+
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
315+
ins = prog.add_instruction(op::pad{padding, pad_val}, ins);
316+
}
317+
else
318+
{
319+
op.padding[0] = padding[0];
320+
op.padding[1] = padding[1];
321+
}
322+
}
323+
305324
instruction_ref parse_clip(const std::string&,
306325
const attribute_map& attributes,
307326
std::vector<instruction_ref> args)
@@ -424,7 +443,8 @@ struct onnx_parser
424443
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
425444
{
426445
Op op;
427-
auto l0 = args[0];
446+
auto l0 = args[0];
447+
auto weights = args[1];
428448
if(contains(attributes, "pads"))
429449
{
430450
if(contains(attributes, "auto_pad"))
@@ -441,17 +461,7 @@ struct onnx_parser
441461
{
442462
MIGRAPHX_THROW("padding should have 4 values");
443463
}
444-
if(padding[0] != padding[2] || padding[1] != padding[3])
445-
{
446-
// insert zeros for pad op (args[0] has 4 dims)
447-
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
448-
l0 = prog.add_instruction(op::pad{padding}, l0);
449-
}
450-
else
451-
{
452-
op.padding[0] = padding[0];
453-
op.padding[1] = padding[1];
454-
}
464+
check_asym_padding(l0, padding, op);
455465
}
456466
if(contains(attributes, "strides"))
457467
{
@@ -471,7 +481,19 @@ struct onnx_parser
471481

472482
if(s.find("SAME") != std::string::npos)
473483
{
474-
op.padding_mode = op::padding_mode_t::same;
484+
op.padding_mode = op::padding_mode_t::same;
485+
std::vector<size_t> weight_dims = weights->get_shape().lens();
486+
size_t weight_h = weight_dims[2];
487+
size_t weight_w = weight_dims[3];
488+
489+
auto input_dims = l0->get_shape().lens();
490+
std::vector<int64_t> padding(input_dims.size());
491+
calculate_padding(
492+
0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
493+
calculate_padding(
494+
1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);
495+
496+
check_asym_padding(l0, padding, op);
475497
}
476498
}
477499
if(contains(attributes, "group"))
@@ -618,27 +640,10 @@ struct onnx_parser
618640
{
619641
MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
620642
}
621-
if(padding[0] != padding[2] || padding[1] != padding[3])
622-
{
623-
// insert zeros for pad op (args[0] has 4 dims)
624-
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
625-
// MaxPool
626-
if(op.mode == "max")
627-
{
628-
l0 = prog.add_instruction(
629-
op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
630-
}
631-
// AveragePool
632-
else
633-
{
634-
l0 = prog.add_instruction(op::pad{padding}, l0);
635-
}
636-
}
637-
else
638-
{
639-
op.padding[0] = padding[0];
640-
op.padding[1] = padding[1];
641-
}
643+
float pad_val = 0;
644+
if(op.mode == "max")
645+
pad_val = std::numeric_limits<float>::lowest();
646+
check_asym_padding(l0, padding, op, pad_val);
642647
}
643648

644649
if(contains(attributes, "strides"))

‎test/onnx/conv_autopad_same_test.onnx

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
conv_autopad_same_test:�
2+
J
3+
0
4+
12"Conv*
5+
auto_pad"SAME�*
6+
dilations@@�*
7+
strides@@�conv_autopad_same_testZ
8+
0
9+

10+

11+

12+

13+
 Z
14+
1
15+

16+

17+

18+

19+
b
20+
2
21+

22+

23+

24+

25+
 B

‎test/onnx/gen_onnx.py

+16
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,22 @@ def conv_autopad_fail_test():
492492
return ([node], [x, y], [out])
493493

494494

495+
@onnx_test
496+
def conv_autopad_same_test():
497+
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32])
498+
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3, 3])
499+
out = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 32, 32])
500+
501+
node = onnx.helper.make_node('Conv',
502+
inputs=['0', '1'],
503+
outputs=['2'],
504+
dilations=[1, 1],
505+
strides=[1, 1],
506+
auto_pad='SAME')
507+
508+
return ([node], [x, y], [out])
509+
510+
495511
@onnx_test
496512
def conv_bias_test():
497513
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32])

‎test/onnx/onnx_test.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,20 @@ TEST_CASE(conv_autopad_fail_test)
341341
EXPECT(test::throws([&] { optimize_onnx("conv_autopad_fail_test.onnx"); }));
342342
}
343343

344+
TEST_CASE(conv_autopad_same_test)
345+
{
346+
migraphx::program p;
347+
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
348+
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
349+
migraphx::op::convolution op;
350+
op.padding = {1, 1};
351+
op.padding_mode = migraphx::op::padding_mode_t::same;
352+
p.add_instruction(op, l0, l1);
353+
354+
auto prog = optimize_onnx("conv_autopad_same_test.onnx");
355+
EXPECT(p == prog);
356+
}
357+
344358
TEST_CASE(conv_bias_test)
345359
{
346360
migraphx::program p;

0 commit comments

Comments (0)