Bug 555712: Struct and union issues
[gnome.gobject-introspection] / giscanner / transformer.py
1 # -*- Mode: Python -*-
2 # GObject-Introspection - a framework for introspecting GObject libraries
3 # Copyright (C) 2008  Johan Dahlin
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # of the License, or (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
18 # 02110-1301, USA.
19 #
20
21 import os
22
23 from giscanner.ast import (Callback, Enum, Function, Namespace, Member,
24                            Parameter, Return, Array, Struct, Field,
25                            Type, Alias, Interface, Class, Node, Union,
26                            List, Map, Varargs, type_name_from_ctype,
27                            type_names, default_array_types)
28 from giscanner.config import DATADIR
29 from .glibast import GLibBoxed
30 from giscanner.sourcescanner import (
31     SourceSymbol, ctype_name, CTYPE_POINTER,
32     CTYPE_BASIC_TYPE, CTYPE_UNION, CTYPE_ARRAY, CTYPE_TYPEDEF,
33     CTYPE_VOID, CTYPE_ENUM, CTYPE_FUNCTION, CTYPE_STRUCT,
34     CSYMBOL_TYPE_FUNCTION, CSYMBOL_TYPE_TYPEDEF, CSYMBOL_TYPE_STRUCT,
35     CSYMBOL_TYPE_ENUM, CSYMBOL_TYPE_UNION, CSYMBOL_TYPE_OBJECT,
36     CSYMBOL_TYPE_MEMBER, CSYMBOL_TYPE_ELLIPSIS)
37 from .odict import odict
38 from .utils import strip_common_prefix, to_underscores
39
# Base directories that may hold a 'gir' subdirectory with included .gir
# files: the entries of $XDG_DATA_DIRS followed by DATADIR and /usr/share,
# with empty entries (e.g. from an unset variable or '::') dropped.
_xdg_data_dirs = list(filter(None,
                             os.environ.get('XDG_DATA_DIRS', '').split(':')
                             + [DATADIR, '/usr/share']))
42
43
class SkipError(Exception):
    """Signal that a C symbol uses a type the transformer cannot represent
    (such as va_list or FILE*) and should be left out of the output."""
46
47
class Names(object):
    """Lookup tables used when resolving type names.

    Every table maps a key to a ``(namespace, node)`` pair, where
    ``namespace`` is the name of the included namespace the node was read
    from (the Transformer stores None for nodes of the namespace being
    scanned).  The tables are exposed through read-only properties; callers
    update the returned containers in place.
    """

    def __init__(self):
        super(Names, self).__init__()
        # GI name -> (namespace, node)
        self._names = odict()
        # GI name of an alias -> (namespace, Alias node)
        self._aliases = {}
        # type name -> (namespace, node)
        self._type_names = {}
        # C type or C symbol name -> (namespace, node)
        self._ctypes = {}

    @property
    def names(self):
        """Table of all known nodes, keyed by GI name."""
        return self._names

    @property
    def aliases(self):
        """Table of Alias nodes, keyed by GI name."""
        return self._aliases

    @property
    def type_names(self):
        """Table of nodes keyed by type name."""
        return self._type_names

    @property
    def ctypes(self):
        """Table of nodes keyed by C type (or C symbol) name."""
        return self._ctypes
60
61
class Transformer(object):
    """Convert C symbols from a source scanner into GI AST nodes.

    Every symbol yielded by ``generator.get_symbols()`` is turned into the
    matching node (Function, Struct, Union, Enum, Callback, Alias, ...) and
    appended to a single Namespace.  Nodes read from included .gir files are
    recorded in a Names table so that types referenced by the scanned
    headers can be resolved to their GI names.
    """

    def __init__(self, generator, namespace_name):
        self.generator = generator
        self._namespace = Namespace(namespace_name)
        self._names = Names()
        # C identifier of a struct/union typedef -> the node created for it;
        # _create_struct/_create_union fill in that node instead of creating
        # a second one.
        self._typedefs_ns = {}
        self._strip_prefix = ''
        # Names (basename without suffix) of the includes registered so far.
        self._includes = set()
        # Extra directories searched for .gir includes; the directory of
        # every include that was found is appended.
        self._includepaths = []

        # C type strings (as built by _create_source_type, including any
        # '*') that are turned into List and Map nodes respectively.
        self._list_ctypes = []
        self._map_ctypes = []

    def get_names(self):
        """Return the Names table used for type resolution."""
        return self._names

    def get_includes(self):
        """Return the set of include names registered so far."""
        return self._includes

    def set_container_types(self, list_ctypes, map_ctypes):
        """Set the C type strings that are mapped to List and Map nodes."""
        self._list_ctypes = list_ctypes
        self._map_ctypes = map_ctypes

    def set_strip_prefix(self, strip_prefix):
        """Set the prefix removed from C names (see _remove_prefix)."""
        self._strip_prefix = strip_prefix

    def parse(self):
        """Traverse every scanner symbol and return the populated Namespace."""
        for symbol in self.generator.get_symbols():
            self._add_node(self._traverse_one(symbol))
        return self._namespace

    def register_include(self, filename):
        """Load an included .gir file and register the nodes it defines.

        ``filename`` may omit the '.gir' suffix.  If it does not exist as
        given, it is searched for in ``<XDG data dir>/gir`` and in the
        directories of previously found includes.  Includes of the loaded
        file are registered recursively; each include name is only
        processed once.

        :raises ValueError: if the .gir file cannot be found
        :raises NotImplementedError: for suffixes other than '.gir'
        """
        (path, suffix) = os.path.splitext(filename)
        name = os.path.basename(path)
        if name in self._includes:
            return
        if suffix == '':
            suffix = '.gir'
            filename = path + suffix
        if suffix != '.gir':
            raise NotImplementedError(filename)

        source = filename
        if not os.path.exists(filename):
            searchdirs = [os.path.join(d, 'gir') for d in _xdg_data_dirs]
            searchdirs.extend(self._includepaths)
            source = None
            for d in searchdirs:
                candidate = os.path.join(d, filename)
                if os.path.exists(candidate):
                    source = candidate
                    break
            if source is None:
                raise ValueError("Couldn't find include %r (search path: %r)"
                                 % (filename, searchdirs))
        d = os.path.dirname(source)
        if d not in self._includepaths:
            self._includepaths.append(d)
        self._includes.add(name)

        from .girparser import GIRParser
        parser = GIRParser(source)
        for include in parser.get_includes():
            self.register_include(include)
        nsname = parser.get_namespace().name
        for node in parser.get_namespace().nodes:
            if isinstance(node, Alias):
                self._names.aliases[node.name] = (nsname, node)
            elif isinstance(node, (GLibBoxed, Interface, Class)):
                self._names.type_names[node.type_name] = (nsname, node)
            self._names.names[node.name] = (nsname, node)
            if hasattr(node, 'ctype'):
                self._names.ctypes[node.ctype] = (nsname, node)
            elif hasattr(node, 'symbol'):
                self._names.ctypes[node.symbol] = (nsname, node)

    def strip_namespace_object(self, name):
        """Strip the namespace name from the front of an object name.

        If ``name`` is longer than the namespace name and starts with it
        (case-insensitively), the remainder is returned; otherwise the
        result of _remove_prefix(name) is returned.
        """
        prefix = self._namespace.name.lower()
        if len(name) > len(prefix) and name.lower().startswith(prefix):
            return name[len(prefix):]
        return self._remove_prefix(name)

    # Private

    def _add_node(self, node):
        """Append ``node`` to the namespace and register it by name.

        None and nodes whose name starts with '_' are ignored.
        """
        if node is None:
            return
        if node.name.startswith('_'):
            return
        self._namespace.nodes.append(node)
        self._names.names[node.name] = (None, node)

    def _strip_namespace_func(self, name):
        """Strip the namespace and the strip prefix from a function name.

        Both the lowercased namespace name and its to_underscores() form,
        each followed by '_', are tried before _remove_prefix is applied.
        """
        prefix = self._namespace.name.lower() + '_'
        if name.lower().startswith(prefix):
            name = name[len(prefix):]
        else:
            prefix = to_underscores(self._namespace.name).lower() + '_'
            if name.lower().startswith(prefix):
                name = name[len(prefix):]
        return self._remove_prefix(name, isfunction=True)

    def _remove_prefix(self, name, isfunction=False):
        """Remove the configured strip prefix and any leading underscores.

        The prefix match is case-insensitive; for functions the prefix is
        followed by '_'.  With --strip-prefix=g:
          GHashTable -> HashTable
          g_hash_table_new -> hash_table_new
        """
        prefix = self._strip_prefix.lower()
        if isfunction:
            prefix += '_'
        if name.lower().startswith(prefix):
            name = name[len(prefix):]

        while name.startswith('_'):
            name = name[1:]
        return name

    def _traverse_one(self, symbol, stype=None):
        """Convert one scanner symbol into an AST node.

        :param symbol: the SourceSymbol to convert
        :param stype: symbol type to dispatch on; defaults to symbol.type
        :returns: the new node, or None when the symbol has to be skipped
        :raises NotImplementedError: for unknown symbol types
        """
        assert isinstance(symbol, SourceSymbol), symbol

        if stype is None:
            stype = symbol.type
        try:
            if stype == CSYMBOL_TYPE_FUNCTION:
                return self._create_function(symbol)
            elif stype == CSYMBOL_TYPE_TYPEDEF:
                return self._create_typedef(symbol)
            elif stype == CSYMBOL_TYPE_STRUCT:
                return self._create_struct(symbol)
            elif stype == CSYMBOL_TYPE_ENUM:
                return self._create_enum(symbol)
            elif stype == CSYMBOL_TYPE_OBJECT:
                return self._create_object(symbol)
            elif stype == CSYMBOL_TYPE_MEMBER:
                return self._create_member(symbol)
            elif stype == CSYMBOL_TYPE_UNION:
                return self._create_union(symbol)
        except SkipError:
            # A symbol using a type we cannot represent (va_list, FILE*) is
            # dropped: this applies to struct/union fields and callbacks as
            # well as functions, instead of aborting the whole scan.
            return None
        raise NotImplementedError(
            'Transformer: unhandled symbol: %r' % (symbol, ))

    def _create_enum(self, symbol):
        """Create an Enum node and register it under its C identifier."""
        members = []
        for child in symbol.base_type.child_list:
            name = strip_common_prefix(symbol.ident, child.ident).lower()
            members.append(Member(name,
                                  child.const_int,
                                  child.ident))

        enum_name = self.strip_namespace_object(symbol.ident)
        enum_name = symbol.ident[-len(enum_name):]
        enum_name = self._remove_prefix(enum_name)
        enum = Enum(enum_name, symbol.ident, members)
        self._names.type_names[symbol.ident] = (None, enum)
        return enum

    def _create_object(self, symbol):
        """Create a Member node for an object symbol."""
        return Member(symbol.ident, symbol.base_type.name,
                      symbol.ident)

    def _parse_deprecated(self, node, directives):
        """Copy a 'deprecated' directive onto ``node``.

        A value of the form 'VERSION: text' sets both
        ``node.deprecated_version`` and ``node.deprecated``; otherwise only
        ``node.deprecated`` is set.
        """
        deprecated = directives.get('deprecated', False)
        if deprecated:
            deprecated_value = deprecated[0]
            if ':' in deprecated_value:
                # Split out gtk-doc version
                (node.deprecated_version, node.deprecated) = \
                    [x.strip() for x in deprecated_value.split(':', 1)]
            else:
                # No version, just include str
                node.deprecated = deprecated_value.strip()

    def _pair_array(self, params, array):
        """Record the index of the length parameter of an array parameter.

        :raises ValueError: if the named length parameter does not exist
        """
        target_name = array.type.length_param_name
        if not target_name:
            return
        for i, param in enumerate(params):
            if param.name == target_name:
                array.type.length_param_index = i
                return
        raise ValueError("Unmatched length parameter name %r"
                         % (target_name, ))

    def _pair_annotations(self, params):
        """Check parameter names are unique and pair arrays with lengths.

        :raises ValueError: on duplicate parameter names
        """
        names = {}
        for param in params:
            if param.name in names:
                raise ValueError("Duplicate parameter name %r"
                                 % (param.name, ))
            names[param.name] = 1
            if isinstance(param.type, Array):
                self._pair_array(params, param)

    def _create_function(self, symbol):
        """Create a Function node, applying the symbol's annotations.

        :raises SkipError: if a parameter or the return type is unsupported
        """
        directives = symbol.directives()
        parameters = list(self._create_parameters(
            symbol.base_type, directives))
        self._pair_annotations(parameters)
        return_ = self._create_return(symbol.base_type.base_type,
                                      directives.get('return', []))
        name = self._strip_namespace_func(symbol.ident)
        func = Function(name, return_, parameters, symbol.ident)
        self._parse_deprecated(func, directives)
        return func

    def _create_source_type(self, source_type):
        """Return the C type string for a scanner type.

        Each pointer level appends '*', arrays are reported as their element
        type, None gives 'None' and unknown kinds give 'any'.
        """
        if source_type is None:
            return 'None'
        if source_type.type == CTYPE_VOID:
            value = 'void'
        elif source_type.type == CTYPE_BASIC_TYPE:
            value = source_type.name
        elif source_type.type == CTYPE_TYPEDEF:
            value = source_type.name
        elif source_type.type == CTYPE_ARRAY:
            return self._create_source_type(source_type.base_type)
        elif source_type.type == CTYPE_POINTER:
            value = self._create_source_type(source_type.base_type) + '*'
        else:
            print('TRANSFORMER: Unhandled source type %r' % (
                source_type, ))
            value = 'any'
        return value

    def _create_parameters(self, base_type, options=None):
        """Yield a Parameter for each child of ``base_type``.

        ``options`` maps parameter names to their annotation lists.
        """
        if not options:
            options = {}
        for child in base_type.child_list:
            yield self._create_parameter(
                child, options.get(child.ident, []))

    def _create_member(self, symbol):
        """Create a Field, or a Callback for a function-pointer member.

        :raises SkipError: if the member's type is unsupported
        """
        ctype = symbol.base_type.type
        if (ctype == CTYPE_POINTER and
            symbol.base_type.base_type.type == CTYPE_FUNCTION):
            node = self._create_callback(symbol)
        else:
            ftype = self._create_type(symbol.base_type)
            node = Field(symbol.ident, ftype, symbol.ident, symbol.const_int)
        return node

    def _create_typedef(self, symbol):
        """Create the node for a typedef symbol.

        Returns a Callback, Struct, Union, Enum or Alias, or None for an
        alias whose name is a built-in type name.

        :raises NotImplementedError: for unsupported typedef targets
        """
        ctype = symbol.base_type.type
        if (ctype == CTYPE_POINTER and
            symbol.base_type.base_type.type == CTYPE_FUNCTION):
            node = self._create_callback(symbol)
        elif ctype == CTYPE_STRUCT:
            node = self._create_typedef_struct(symbol)
        elif ctype == CTYPE_UNION:
            node = self._create_typedef_union(symbol)
        elif ctype == CTYPE_ENUM:
            return self._create_enum(symbol)
        elif ctype in (CTYPE_TYPEDEF,
                       CTYPE_POINTER,
                       CTYPE_BASIC_TYPE,
                       CTYPE_VOID):
            name = self.strip_namespace_object(symbol.ident)
            if symbol.base_type.name:
                target = self.strip_namespace_object(symbol.base_type.name)
            else:
                target = 'none'
            if name in type_names:
                return None
            return Alias(name, target, ctype=symbol.ident)
        else:
            raise NotImplementedError(
                "symbol %r of type %s" % (symbol.ident, ctype_name(ctype)))
        return node

    def _parse_ctype(self, ctype):
        """Return the canonical type name for ``ctype`` with '*' removed."""
        canonical = type_name_from_ctype(ctype)
        derefed = canonical.replace('*', '')
        return derefed

    def _create_type(self, source_type, options=None):
        """Build a type node for a C source type.

        Annotation tokens used here (list element type, map key and value
        types, 'array') are removed from ``options`` in place, so the
        caller only sees the remaining ones.

        :param source_type: scanner type of a parameter, return or field
        :param options: list of annotation tokens, or None
        :returns: a List, Map, Array or Type node
        :raises SkipError: for va_list and FILE* types
        """
        if options is None:
            options = []
        ctype = self._create_source_type(source_type)
        if ctype == 'va_list':
            raise SkipError()
        # FIXME: FILE* should not be skipped, it should be handled
        #        properly instead
        elif ctype == 'FILE*':
            raise SkipError
        if ctype in self._list_ctypes:
            if len(options) > 0:
                contained_type = self._parse_ctype(options[0])
                del options[0]
            else:
                contained_type = None
            return List(ctype.replace('*', ''),
                        ctype,
                        contained_type)
        if ctype in self._map_ctypes:
            # Key and value types are only used when both are given.
            if len(options) > 1:
                key_type = self._parse_ctype(options[0])
                value_type = self._parse_ctype(options[1])
                del options[0:2]
            else:
                key_type = None
                value_type = None
            return Map(ctype.replace('*', ''),
                       ctype,
                       key_type, value_type)
        if (ctype in default_array_types) or ('array' in options):
            if 'array' in options:
                options.remove('array')
            derefed = ctype[:-1] # strip the *
            return Array(ctype,
                         self._parse_ctype(derefed))
        resolved_type_name = self._parse_ctype(ctype)
        return Type(resolved_type_name, ctype)

    def _create_parameter(self, symbol, options):
        """Create a Parameter and apply its annotation ``options``.

        :raises SkipError: if the parameter type is unsupported
        """
        if symbol.type == CSYMBOL_TYPE_ELLIPSIS:
            ptype = Varargs()
        else:
            ptype = self._create_type(symbol.base_type, options)
        param = Parameter(symbol.ident, ptype)
        for option in options:
            if option in ['in-out', 'inout']:
                param.direction = 'inout'
            elif option == 'in':
                param.direction = 'in'
            elif option == 'out':
                param.direction = 'out'
            elif option == 'transfer':
                param.transfer = True
            elif option == 'notransfer':
                param.transfer = False
            elif isinstance(ptype, Array) and option.startswith('length'):
                (_, index_param) = option.split('=')
                ptype.length_param_name = index_param
            elif option == 'allow-none':
                param.allow_none = True
            else:
                print('Unhandled parameter annotation option: %r' % (
                    option, ))
        return param

    def _create_return(self, source_type, options=None):
        """Create a Return node with a resolved type and annotations.

        :raises SkipError: if the return type is unsupported
        """
        if options is None:
            options = []
        rtype = self._create_type(source_type, options)
        rtype = self.resolve_param_type(rtype)
        return_ = Return(rtype)
        for option in options:
            if option == 'transfer':
                return_.transfer = True
            else:
                print('Unhandled parameter annotation option: %r' % (
                    option, ))
        return return_

    def _create_typedef_struct(self, symbol):
        """Create the Struct for ``typedef struct ... Name``."""
        return self._create_typedef_compound(symbol, Struct)

    def _create_typedef_union(self, symbol):
        """Create the Union for ``typedef union ... Name``."""
        return self._create_typedef_compound(symbol, Union)

    def _create_typedef_compound(self, symbol, node_class):
        """Create a Struct/Union named after a typedef and add its fields."""
        name = self.strip_namespace_object(symbol.ident)
        node = node_class(name, symbol.ident)
        # Registered first so that _create_compound fills in this node.
        self._typedefs_ns[symbol.ident] = node
        self._create_compound(symbol, node_class)
        return node

    def _create_struct(self, symbol):
        """Return the Struct for ``symbol`` with its fields added."""
        return self._create_compound(symbol, Struct)

    def _create_union(self, symbol):
        """Return the Union for ``symbol`` with its fields added."""
        return self._create_compound(symbol, Union)

    def _create_compound(self, symbol, node_class):
        """Return the Struct/Union node for ``symbol`` with its fields added.

        Reuses the node registered by a typedef of the same identifier;
        otherwise a new ``node_class`` instance is created.  Fields whose
        type cannot be represented are skipped.
        """
        node = self._typedefs_ns.get(symbol.ident, None)
        if node is None:
            # This is a bit of a hack; really we should try
            # to resolve through the typedefs to find the real
            # name
            if symbol.ident.startswith('_'):
                name = symbol.ident[1:]
            else:
                name = symbol.ident
            name = self.strip_namespace_object(name)
            name = self.resolve_type_name(name)
            node = node_class(name, symbol.ident)

        for child in symbol.base_type.child_list:
            field = self._traverse_one(child)
            if field:
                node.fields.append(field)

        return node

    def _create_callback(self, symbol):
        """Create a Callback from a function-pointer symbol.

        :raises SkipError: if a parameter or the return type is unsupported
        """
        parameters = self._create_parameters(symbol.base_type.base_type)
        retval = self._create_return(symbol.base_type.base_type.base_type)
        if symbol.ident.find('_') > 0:
            name = self._strip_namespace_func(symbol.ident)
        else:
            name = self.strip_namespace_object(symbol.ident)
        return Callback(name, retval, list(parameters), symbol.ident)

    def _typepair_to_str(self, item):
        """Format a (namespace, node) pair as 'Namespace.Name' or 'Name'."""
        nsname, item = item
        if nsname is None:
            return item.name
        return '%s.%s' % (nsname, item.name)

    def _resolve_type_name_1(self, type_name, ctype, names):
        """Resolve a type name using the built-in names and ``names``.

        Lookup order: built-in type names by ctype, then by type_name;
        then (after namespace stripping) aliases, names, ctypes (with '*'
        removed) and type_names of ``names``.

        :raises KeyError: if nothing matches
        """
        # First look using the built-in names
        if ctype:
            try:
                return type_names[ctype]
            except KeyError:
                pass
        try:
            return type_names[type_name]
        except KeyError:
            pass
        type_name = self.strip_namespace_object(type_name)
        resolved = names.aliases.get(type_name)
        if resolved:
            return self._typepair_to_str(resolved)
        resolved = names.names.get(type_name)
        if resolved:
            return self._typepair_to_str(resolved)
        if ctype:
            ctype = ctype.replace('*', '')
            resolved = names.ctypes.get(ctype)
            if resolved:
                return self._typepair_to_str(resolved)
        resolved = names.type_names.get(type_name)
        if resolved:
            return self._typepair_to_str(resolved)
        raise KeyError("failed to find %r" % (type_name, ))

    def resolve_type_name_full(self, type_name, ctype,
                               names):
        """Resolve via ``names``, then via the transformer's own table.

        Returns ``type_name`` unchanged when neither knows it.
        """
        try:
            return self._resolve_type_name_1(type_name, ctype, names)
        except KeyError:
            try:
                return self._resolve_type_name_1(type_name, ctype, self._names)
            except KeyError:
                return type_name

    def resolve_type_name(self, type_name, ctype=None):
        """Resolve ``type_name`` using the transformer's own Names table."""
        try:
            return self.resolve_type_name_full(type_name, ctype, self._names)
        except KeyError:
            return type_name

    def gtypename_to_giname(self, gtname, names):
        """Map a GType name to a GI name via ``names``, then our own table.

        :raises KeyError: if the GType name is unknown
        """
        resolved = names.type_names.get(gtname)
        if resolved:
            return self._typepair_to_str(resolved)
        resolved = self._names.type_names.get(gtname)
        if resolved:
            return self._typepair_to_str(resolved)
        raise KeyError("Failed to resolve GType name: %r" % (gtname, ))

    def ctype_of(self, obj):
        """Return ``obj.ctype``, else ``obj.symbol``, else None."""
        if hasattr(obj, 'ctype'):
            return obj.ctype
        elif hasattr(obj, 'symbol'):
            return obj.symbol
        else:
            return None

    def resolve_param_type_full(self, ptype, names):
        """Resolve the type names in a type node or type-name string.

        Nodes are updated in place (recursing into array/list element
        types and map key/value types) and returned; strings are resolved
        and the result returned.

        :raises AssertionError: for any other kind of ``ptype``
        """
        if isinstance(ptype, Node):
            ptype.name = self.resolve_type_name_full(ptype.name,
                                                     self.ctype_of(ptype),
                                                     names)
            if isinstance(ptype, (Array, List)):
                if ptype.element_type is not None:
                    ptype.element_type = \
                        self.resolve_param_type_full(ptype.element_type, names)
            if isinstance(ptype, Map):
                if ptype.key_type is not None:
                    ptype.key_type = \
                        self.resolve_param_type_full(ptype.key_type, names)
                    ptype.value_type = \
                        self.resolve_param_type_full(ptype.value_type, names)
        elif isinstance(ptype, basestring):
            return self.resolve_type_name_full(ptype, None, names)
        else:
            raise AssertionError("Unhandled param: %r" % (ptype, ))
        return ptype

    def resolve_param_type(self, ptype):
        """Resolve ``ptype`` using the transformer's own Names table."""
        try:
            return self.resolve_param_type_full(ptype, self._names)
        except KeyError:
            return ptype

    def follow_aliases(self, type_name, names):
        """Follow ``names.aliases`` targets until a non-alias name is reached."""
        while True:
            resolved = names.aliases.get(type_name)
            if resolved:
                (ns, alias) = resolved
                type_name = alias.target
            else:
                break
        return type_name
583                 break
584         return type_name