qemu-devel

From: Peter Maydell
Subject: Re: [PATCH v1 08/11] target/arm: Implement bfloat16 matrix multiply accumulate
Date: Tue, 18 May 2021 13:37:51 +0100

On Sat, 17 Apr 2021 at 01:00, Richard Henderson
<richard.henderson@linaro.org> wrote:
>
> This is BFMMLA for both AArch64 AdvSIMD and SVE,
> and VMMLA.BF16 for AArch32 NEON.
>
> Signed-off-by: Richard Henderson <richard.henderson@linaro.org>

> +void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va, uint32_t desc)
> +{
> +    intptr_t s, opr_sz = simd_oprsz(desc);
> +    float32 *d = vd, *a = va;
> +    uint32_t *n = vn, *m = vm;
> +
> +    for (s = 0; s < opr_sz / 4; s += 4) {
> +        float32 sum00, sum01, sum10, sum11;
> +
> +        /*
> +         * Process the entire segment at once, writing back the
> +         * results only after we've consumed all of the inputs.
> +         *
> +         * Key to indicies by column:

"indices"

> +         *               i   j           i   k             j   k
> +         */
> +        sum00 = a[s + H4(0 + 0)];
> +        sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]);
> +        sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]);

I can't make these indices match up with the Arm ARM pseudocode ones,
which index by "4*i + 2*k + 0" and "4*i + 2*k + 1", not "2*i + k";
are we hiding a division by 2 somewhere?
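The index question can be checked mechanically. Assume each uint32_t word of n and m packs two consecutive bfloat16 elements, which the uint32_t pointer types in the patch suggest. Under that assumption, the pseudocode's bf16 element index 4*i + 2*k + e (with e in {0, 1}) falls in word (4*i + 2*k + e) / 2 = 2*i + k, and e selects the half within that word. The sketch below exhaustively checks this mapping for a 2x4 operand. The function name bfmmla_indices_agree is hypothetical. The sketch does not model which half is low or high, and it does not model the H4() host-endianness adjustment.

```c
#include <assert.h>   /* used by callers checking the result */
#include <stdbool.h>

/*
 * Hypothetical check, assuming each uint32_t packs two consecutive
 * bfloat16 elements: bf16 element x then lives in word x / 2, half x % 2.
 * Arm ARM pseudocode index for row r, pair k, half e: 4*r + 2*k + e.
 * Index used in the patch for the same pair: 2*r + k.
 */
static bool bfmmla_indices_agree(void)
{
    for (int r = 0; r < 2; r++) {
        for (int k = 0; k < 2; k++) {
            for (int e = 0; e < 2; e++) {
                int elt = 4 * r + 2 * k + e;   /* bf16 element index */
                int word = 2 * r + k;          /* uint32_t word index */
                /* Dividing by 2 picks the word; the remainder picks the half. */
                if (elt / 2 != word || elt % 2 != e) {
                    return false;
                }
            }
        }
    }
    return true;
}
```

If the packing assumption holds, the division by 2 is implicit in reading n and m as uint32_t rather than as 16-bit elements.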

> +
> +        sum01 = a[s + H4(0 + 1)];
> +        sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]);
> +        sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]);
> +
> +        sum10 = a[s + H4(2 + 0)];
> +        sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]);
> +        sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]);
> +
> +        sum11 = a[s + H4(2 + 1)];
> +        sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]);
> +        sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]);
> +
> +        d[s + H4(0 + 0)] = sum00;
> +        d[s + H4(0 + 1)] = sum01;
> +        d[s + H4(2 + 0)] = sum10;
> +        d[s + H4(2 + 1)] = sum11;
> +    }
> +    clear_tail(d, opr_sz, simd_maxsz(desc));
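For comparison against the helper, a simplified scalar model of one 128-bit segment may help. It computes d (2x2, float32) = a + n (2x4, bf16) times the transpose of m (2x4, bf16), indexed exactly as in the Arm ARM pseudocode. This is a sketch, not QEMU code. The names bf16_to_float and bfmmla_segment_ref are hypothetical, and plain float arithmetic stands in for bfdotadd, so rounding and denormal handling will not match the architecture bit-for-bit.

```c
#include <assert.h>   /* used by callers checking the result */
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: bfloat16 is the top 16 bits of an IEEE binary32. */
static float bf16_to_float(uint16_t b)
{
    uint32_t bits = (uint32_t)b << 16;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

/*
 * Simplified reference for one 128-bit segment (assumption: ordinary
 * float multiply and add, not the architectural BFDotAdd rounding).
 * n and m are row-major 2x4 bf16 matrices; a and d are row-major 2x2.
 */
static void bfmmla_segment_ref(float d[4], const float a[4],
                               const uint16_t n[8], const uint16_t m[8])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            float sum = a[2 * i + j];
            for (int k = 0; k < 2; k++) {
                /* Pseudocode indices 4*i + 2*k + {0,1} and 4*j + 2*k + {0,1}. */
                sum += bf16_to_float(n[4 * i + 2 * k + 0]) *
                       bf16_to_float(m[4 * j + 2 * k + 0]);
                sum += bf16_to_float(n[4 * i + 2 * k + 1]) *
                       bf16_to_float(m[4 * j + 2 * k + 1]);
            }
            d[2 * i + j] = sum;
        }
    }
}
```

The helper performs this computation once per segment. Its loop advances s by 4 float32 elements (16 bytes) while s < opr_sz / 4. clear_tail() then zeroes the destination from opr_sz up to simd_maxsz(desc).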

Otherwise
Reviewed-by: Peter Maydell <peter.maydell@linaro.org>

thanks
-- PMM


