16
16
#include < migraphx/instruction.hpp>
17
17
#include < migraphx/config.hpp>
18
18
#include < migraphx/onnx.hpp>
19
+ #include < migraphx/pad_calc.hpp>
19
20
20
21
namespace migraphx {
21
22
inline namespace MIGRAPHX_INLINE_NS {
@@ -302,6 +303,24 @@ struct onnx_parser
302
303
return curr_ins;
303
304
}
304
305
306
+ template <class Op >
307
+ void check_asym_padding (instruction_ref& ins,
308
+ std::vector<int64_t >& padding,
309
+ Op& op,
310
+ float pad_val = 0 )
311
+ {
312
+ if (padding[0 ] != padding[2 ] || padding[1 ] != padding[3 ])
313
+ {
314
+ padding = {0 , 0 , padding[0 ], padding[1 ], 0 , 0 , padding[2 ], padding[3 ]};
315
+ ins = prog.add_instruction (op::pad{padding, pad_val}, ins);
316
+ }
317
+ else
318
+ {
319
+ op.padding [0 ] = padding[0 ];
320
+ op.padding [1 ] = padding[1 ];
321
+ }
322
+ }
323
+
305
324
instruction_ref parse_clip (const std::string&,
306
325
const attribute_map& attributes,
307
326
std::vector<instruction_ref> args)
@@ -424,7 +443,8 @@ struct onnx_parser
424
443
parse_conv (const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
425
444
{
426
445
Op op;
427
- auto l0 = args[0 ];
446
+ auto l0 = args[0 ];
447
+ auto weights = args[1 ];
428
448
if (contains (attributes, " pads" ))
429
449
{
430
450
if (contains (attributes, " auto_pad" ))
@@ -441,17 +461,7 @@ struct onnx_parser
441
461
{
442
462
MIGRAPHX_THROW (" padding should have 4 values" );
443
463
}
444
- if (padding[0 ] != padding[2 ] || padding[1 ] != padding[3 ])
445
- {
446
- // insert zeros for pad op (args[0] has 4 dims)
447
- padding = {0 , 0 , padding[0 ], padding[1 ], 0 , 0 , padding[2 ], padding[3 ]};
448
- l0 = prog.add_instruction (op::pad{padding}, l0);
449
- }
450
- else
451
- {
452
- op.padding [0 ] = padding[0 ];
453
- op.padding [1 ] = padding[1 ];
454
- }
464
+ check_asym_padding (l0, padding, op);
455
465
}
456
466
if (contains (attributes, " strides" ))
457
467
{
@@ -471,7 +481,19 @@ struct onnx_parser
471
481
472
482
if (s.find (" SAME" ) != std::string::npos)
473
483
{
474
- op.padding_mode = op::padding_mode_t ::same;
484
+ op.padding_mode = op::padding_mode_t ::same;
485
+ std::vector<size_t > weight_dims = weights->get_shape ().lens ();
486
+ size_t weight_h = weight_dims[2 ];
487
+ size_t weight_w = weight_dims[3 ];
488
+
489
+ auto input_dims = l0->get_shape ().lens ();
490
+ std::vector<int64_t > padding (input_dims.size ());
491
+ calculate_padding (
492
+ 0 , padding, input_dims[2 ], op.stride [0 ], op.dilation [0 ], weight_h);
493
+ calculate_padding (
494
+ 1 , padding, input_dims[3 ], op.stride [1 ], op.dilation [1 ], weight_w);
495
+
496
+ check_asym_padding (l0, padding, op);
475
497
}
476
498
}
477
499
if (contains (attributes, " group" ))
@@ -618,27 +640,10 @@ struct onnx_parser
618
640
{
619
641
MIGRAPHX_THROW (" PARSE_POOLING: padding should have 4 values" );
620
642
}
621
- if (padding[0 ] != padding[2 ] || padding[1 ] != padding[3 ])
622
- {
623
- // insert zeros for pad op (args[0] has 4 dims)
624
- padding = {0 , 0 , padding[0 ], padding[1 ], 0 , 0 , padding[2 ], padding[3 ]};
625
- // MaxPool
626
- if (op.mode == " max" )
627
- {
628
- l0 = prog.add_instruction (
629
- op::pad{padding, std::numeric_limits<float >::lowest ()}, l0);
630
- }
631
- // AveragePool
632
- else
633
- {
634
- l0 = prog.add_instruction (op::pad{padding}, l0);
635
- }
636
- }
637
- else
638
- {
639
- op.padding [0 ] = padding[0 ];
640
- op.padding [1 ] = padding[1 ];
641
- }
643
+ float pad_val = 0 ;
644
+ if (op.mode == " max" )
645
+ pad_val = std::numeric_limits<float >::lowest ();
646
+ check_asym_padding (l0, padding, op, pad_val);
642
647
}
643
648
644
649
if (contains (attributes, " strides" ))
0 commit comments