001/*******************************************************************************
002 * Copyright (c) 2016 Diamond Light Source Ltd. and others.
003 * All rights reserved. This program and the accompanying materials
004 * are made available under the terms of the Eclipse Public License v1.0
005 * which accompanies this distribution, and is available at
006 * http://www.eclipse.org/legal/epl-v10.html
007 *
008 * Contributors:
009 *     Diamond Light Source Ltd - initial API and implementation
010 *******************************************************************************/
011package org.eclipse.january.dataset;
012
013import java.lang.reflect.Array;
014import java.util.ArrayList;
015import java.util.Collection;
016import java.util.List;
017import java.util.SortedSet;
018import java.util.TreeSet;
019
020public class ShapeUtils {
021
022        private ShapeUtils() {
023        }
024
025        /**
026         * Calculate total number of items in given shape
027         * @param shape
028         * @return size
029         */
030        public static long calcLongSize(final int[] shape) {
031                if (shape == null) { // special case of null-shaped
032                        return 0;
033                }
034
035                final int rank = shape.length;
036                if (rank == 0) { // special case of zero-rank shape 
037                        return 1;
038                }
039        
040                double dsize = 1.0;
041                for (int i = 0; i < rank; i++) {
042                        // make sure the indexes isn't zero or negative
043                        if (shape[i] == 0) {
044                                return 0;
045                        } else if (shape[i] < 0) {
046                                throw new IllegalArgumentException(String.format(
047                                                "The %d-th is %d which is not allowed as it is negative", i, shape[i]));
048                        }
049        
050                        dsize *= shape[i];
051                }
052        
053                // check to see if the size is larger than an integer, i.e. we can't allocate it
054                if (dsize > Long.MAX_VALUE) {
055                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
056                }
057                return (long) dsize;
058        }
059
060        /**
061         * Calculate total number of items in given shape
062         * @param shape
063         * @return size
064         */
065        public static int calcSize(final int[] shape) {
066                long lsize = calcLongSize(shape);
067        
068                // check to see if the size is larger than an integer, i.e. we can't allocate it
069                if (lsize > Integer.MAX_VALUE) {
070                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
071                }
072                return (int) lsize;
073        }
074
075        /**
076         * Check if shapes are broadcast compatible
077         * 
078         * @param ashape
079         * @param bshape
080         * @return true if they are compatible
081         */
082        public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) {
083                if (ashape == null || bshape == null) {
084                        return ashape == bshape;
085                }
086
087                if (ashape.length < bshape.length) {
088                        return areShapesBroadcastCompatible(bshape, ashape);
089                }
090        
091                for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) {
092                        if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) {
093                                return false;
094                        }
095                }
096        
097                return true;
098        }
099
100        /**
101         * Check if shapes are compatible, ignoring extra axes of length 1
102         * 
103         * @param ashape
104         * @param bshape
105         * @return true if they are compatible
106         */
107        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) {
108                if (ashape == null || bshape == null) {
109                        return ashape == bshape;
110                }
111
112                List<Integer> alist = new ArrayList<Integer>();
113        
114                for (int a : ashape) {
115                        if (a > 1) alist.add(a);
116                }
117        
118                final int imax = alist.size();
119                int i = 0;
120                for (int b : bshape) {
121                        if (b == 1)
122                                continue;
123                        if (i >= imax || b != alist.get(i++))
124                                return false;
125                }
126        
127                return i == imax;
128        }
129
130        /**
131         * Check if shapes are compatible but skip axis
132         * 
133         * @param ashape
134         * @param bshape
135         * @param axis
136         * @return true if they are compatible
137         */
138        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) {
139                if (ashape == null || bshape == null) {
140                        return ashape == bshape;
141                }
142
143                if (ashape.length != bshape.length) {
144                        return false;
145                }
146        
147                final int rank = ashape.length;
148                for (int i = 0; i < rank; i++) {
149                        if (i != axis && ashape[i] != bshape[i]) {
150                                return false;
151                        }
152                }
153                return true;
154        }
155
156        /**
157         * Remove dimensions of 1 in given shape - from both ends only, if true
158         * 
159         * @param oshape
160         * @param onlyFromEnds
161         * @return newly squeezed shape (or original if unsqueezed)
162         */
163        public static int[] squeezeShape(final int[] oshape, boolean onlyFromEnds) {
164                int unitDims = 0;
165                int rank = oshape.length;
166                int start = 0;
167        
168                if (onlyFromEnds) {
169                        int i = rank - 1;
170                        for (; i >= 0; i--) {
171                                if (oshape[i] == 1) {
172                                        unitDims++;
173                                } else {
174                                        break;
175                                }
176                        }
177                        for (int j = 0; j <= i; j++) {
178                                if (oshape[j] == 1) {
179                                        unitDims++;
180                                } else {
181                                        start = j;
182                                        break;
183                                }
184                        }
185                } else {
186                        for (int i = 0; i < rank; i++) {
187                                if (oshape[i] == 1) {
188                                        unitDims++;
189                                }
190                        }
191                }
192        
193                if (unitDims == 0) {
194                        return oshape;
195                }
196        
197                int[] newDims = new int[rank - unitDims];
198                if (unitDims == rank)
199                        return newDims; // zero-rank dataset
200        
201                if (onlyFromEnds) {
202                        rank = newDims.length;
203                        for (int i = 0; i < rank; i++) {
204                                newDims[i] = oshape[i+start];
205                        }
206                } else {
207                        int j = 0;
208                        for (int i = 0; i < rank; i++) {
209                                if (oshape[i] > 1) {
210                                        newDims[j++] = oshape[i];
211                                        if (j >= newDims.length)
212                                                break;
213                                }
214                        }
215                }
216        
217                return newDims;
218        }
219
220        /**
221         * Remove dimension of 1 in given shape
222         * 
223         * @param oshape
224         * @param axis
225         * @return newly squeezed shape
226         */
227        public static int[] squeezeShape(final int[] oshape, int axis) {
228                if (oshape == null) {
229                        return null;
230                }
231
232                final int rank = oshape.length;
233                if (rank == 0) {
234                        return new int[0];
235                }
236                if (axis < 0) {
237                        axis += rank;
238                }
239                if (axis < 0 || axis >= rank) {
240                        throw new IllegalArgumentException("Axis argument is outside allowed range");
241                }
242                int[] nshape = new int[rank-1];
243                for (int i = 0; i < axis; i++) {
244                        nshape[i] = oshape[i];
245                }
246                for (int i = axis+1; i < rank; i++) {
247                        nshape[i-1] = oshape[i];
248                }
249                return nshape;
250        }
251
252        /**
253         * Get shape from object (array or list supported)
254         * @param obj
255         * @return shape can be null if obj is null
256         */
257        public static int[] getShapeFromObject(final Object obj) {
258                if (obj == null) {
259                        return null;
260                }
261
262                ArrayList<Integer> lshape = new ArrayList<Integer>();
263                getShapeFromObj(lshape, obj, 0);
264
265                final int rank = lshape.size();
266                final int[] shape = new int[rank];
267                for (int i = 0; i < rank; i++) {
268                        shape[i] = lshape.get(i);
269                }
270        
271                return shape;
272        }
273
274        /**
275         * Get shape from object
276         * @param ldims
277         * @param obj
278         * @param depth
279         * @return true if there is a possibility of differing lengths
280         */
281        private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) {
282                if (obj == null)
283                        return true;
284        
285                if (obj instanceof List<?>) {
286                        List<?> jl = (List<?>) obj;
287                        int l = jl.size();
288                        updateShape(ldims, depth, l);
289                        for (int i = 0; i < l; i++) {
290                                Object lo = jl.get(i);
291                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
292                                        break;
293                                }
294                        }
295                        return true;
296                }
297                Class<? extends Object> ca = obj.getClass().getComponentType();
298                if (ca != null) {
299                        final int l = Array.getLength(obj);
300                        updateShape(ldims, depth, l);
301                        if (DTypeUtils.isClassSupportedAsElement(ca)) {
302                                return true;
303                        }
304                        for (int i = 0; i < l; i++) {
305                                Object lo = Array.get(obj, i);
306                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
307                                        break;
308                                }
309                        }
310                        return true;
311                } else if (obj instanceof IDataset) {
312                        int[] s = ((IDataset) obj).getShape();
313                        for (int i = 0; i < s.length; i++) {
314                                updateShape(ldims, depth++, s[i]);
315                        }
316                        return true;
317                } else {
318                        return false; // not an array of any type
319                }
320        }
321
322        private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) {
323                if (depth >= ldims.size()) {
324                        ldims.add(l);
325                } else if (l > ldims.get(depth)) {
326                        ldims.set(depth, l);
327                }
328        }
329
330        /**
331         * Get n-D position from given index
332         * @param n index
333         * @param shape
334         * @return n-D position
335         */
336        public static int[] getNDPositionFromShape(int n, int[] shape) {
337                if (shape == null) {
338                        return null;
339                }
340
341                int rank = shape.length;
342                if (rank == 0) {
343                        return new int[0];
344                }
345
346                if (rank == 1) {
347                        return new int[] { n };
348                }
349
350                int[] output = new int[rank];
351                for (rank--; rank > 0; rank--) {
352                        output[rank] = n % shape[rank];
353                        n /= shape[rank];
354                }
355                output[0] = n;
356        
357                return output;
358        }
359
360        /**
361         * Get flattened view index of given position 
362         * @param shape
363         * @param pos
364         *            the integer array specifying the n-D position
365         * @return the index on the flattened dataset
366         */
367        public static int getFlat1DIndex(final int[] shape, final int[] pos) {
368                final int imax = pos.length;
369                if (imax == 0) {
370                        return 0;
371                }
372        
373                return AbstractDataset.get1DIndexFromShape(shape, pos);
374        }
375
376        /**
377         * This function takes a dataset and checks its shape against another dataset. If they are both of the same size,
378         * then this returns with no error, if there is a problem, then an error is thrown.
379         * 
380         * @param g
381         *            The first dataset to be compared
382         * @param h
383         *            The second dataset to be compared
384         * @throws IllegalArgumentException
385         *             This will be thrown if there is a problem with the compatibility
386         */
387        public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException {
388                if (!areShapesCompatible(g.getShape(), h.getShape())) {
389                        throw new IllegalArgumentException("Shapes do not match");
390                }
391        }
392
393        /**
394         * Check that axis is in range [-rank,rank)
395         * 
396         * @param rank
397         * @param axis
398         * @return sanitized axis in range [0, rank)
399         * @since 2.1
400         */
401        public static int checkAxis(int rank, int axis) {
402                if (axis < 0) {
403                        axis += rank;
404                }
405        
406                if (axis < 0 || axis >= rank) {
407                        throw new IllegalArgumentException("Axis " + axis + " given is out of range [0, " + rank + ")");
408                }
409                return axis;
410        }
411
412        private static int[] convert(Collection<Integer> list) {
413                int[] array = new int[list.size()];
414                int i = 0;
415                for (Integer l : list) {
416                        array[i++] = l;
417                }
418                return array;
419        }
420
421        /**
422         * Check that all axes are in range [-rank,rank)
423         * @param rank
424         * @param axes
425         * @return sanitized axes in range [0, rank) and sorted in increasing order
426         * @since 2.2
427         */
428        public static int[] checkAxes(int rank, int... axes) {
429                return convert(sanitizeAxes(rank, axes));
430        }
431
432        /**
433         * Check that all axes are in range [-rank,rank)
434         * @param rank
435         * @param axes
436         * @return sanitized axes in range [0, rank) and sorted in increasing order
437         * @since 2.2
438         */
439        private static SortedSet<Integer> sanitizeAxes(int rank, int... axes) {
440                SortedSet<Integer> nAxes = new TreeSet<>(); 
441                for (int i = 0; i < axes.length; i++) {
442                        nAxes.add(checkAxis(rank, axes[i]));
443                }
444
445                return nAxes;
446        }
447
448        /**
449         * @param rank
450         * @param axes
451         * @return remaining axes not given by input
452         * @since 2.2
453         */
454        public static int[] getRemainingAxes(int rank, int... axes) {
455                SortedSet<Integer> nAxes = sanitizeAxes(rank, axes);
456
457                int[] remains = new int[rank - axes.length];
458                int j = 0;
459                for (int i = 0; i < rank; i++) {
460                        if (!nAxes.contains(i)) {
461                                remains[j++] = i;
462                        }
463                }
464                return remains;
465        }
466
467        /**
468         * Remove axes from shape
469         * @param shape
470         * @param axes
471         * @return reduced shape
472         * @since 2.2
473         */
474        public static int[] reduceShape(int[] shape, int... axes) {
475                int[] remain = getRemainingAxes(shape.length, axes);
476                for (int i = 0; i < remain.length; i++) {
477                        int a = remain[i];
478                        remain[i] = shape[a];
479                }
480                return remain;
481        }
482
483        /**
484         * Set reduced axes to 1
485         * @param shape
486         * @param axes
487         * @return shape with same rank
488         * @since 2.2
489         */
490        public static int[] getReducedShapeKeepRank(int[] shape, int... axes) {
491                int[] keep = shape.clone();
492                axes = checkAxes(shape.length, axes);
493                for (int i : axes) {
494                        keep[i] = 1;
495                }
496                return keep;
497        }
498}