Write out list and map types, parse 'array' annotation
[gnome.gobject-introspection] / giscanner / transformer.py
1 # -*- Mode: Python -*-
2 # GObject-Introspection - a framework for introspecting GObject libraries
3 # Copyright (C) 2008  Johan Dahlin
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # of the License, or (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
18 # 02110-1301, USA.
19 #
20
21 import os
22
23 from giscanner.ast import (Callback, Enum, Function, Namespace, Member,
24                            Parameter, Return, Array, Struct, Field,
25                            Type, Alias, Interface, Class, Node, Union,
26                            List, Map, type_name_from_ctype, type_names,
27                            default_array_types)
28 from giscanner.config import DATADIR
29 from .glibast import GLibBoxed
30 from giscanner.sourcescanner import (
31     SourceSymbol, ctype_name, CTYPE_POINTER,
32     CTYPE_BASIC_TYPE, CTYPE_UNION, CTYPE_ARRAY, CTYPE_TYPEDEF,
33     CTYPE_VOID, CTYPE_ENUM, CTYPE_FUNCTION, CTYPE_STRUCT,
34     CSYMBOL_TYPE_FUNCTION, CSYMBOL_TYPE_TYPEDEF, CSYMBOL_TYPE_STRUCT,
35     CSYMBOL_TYPE_ENUM, CSYMBOL_TYPE_UNION, CSYMBOL_TYPE_OBJECT,
36     CSYMBOL_TYPE_MEMBER)
37 from .odict import odict
38 from .utils import strip_common_prefix, to_underscores
39
# Data directories to search: the ':'-separated entries of $XDG_DATA_DIRS,
# then DATADIR and '/usr/share'.  Empty entries are dropped.
_xdg_data_dirs = [
    x for x in (os.environ.get('XDG_DATA_DIRS', '').split(':')
                + [DATADIR, '/usr/share'])
    if x]
42
43
class SkipError(Exception):
    """Signals that the symbol being transformed should be skipped."""
46
47
class Names(object):
    """Lookup tables for nodes, keyed by the different names they go by.

    Each table is exposed through a read-only property.
    """

    def __init__(self):
        super(Names, self).__init__()
        # GIName -> (namespace, node)
        self._names = odict()
        # Alias GIName -> (namespace, alias node)
        self._aliases = {}
        # GTName -> (namespace, node)
        self._type_names = {}
        # CType -> (namespace, node)
        self._ctypes = {}

    @property
    def names(self):
        return self._names

    @property
    def aliases(self):
        return self._aliases

    @property
    def type_names(self):
        return self._type_names

    @property
    def ctypes(self):
        return self._ctypes
60
61
62 class Transformer(object):
63
64     def __init__(self, generator, namespace_name):
65         self.generator = generator
66         self._namespace = Namespace(namespace_name)
67         self._names = Names()
68         self._typedefs_ns = {}
69         self._strip_prefix = ''
70         self._includes = set()
71         self._includepaths = []
72
73         self._list_ctypes = []
74         self._map_ctypes = []
75
76     def get_names(self):
77         return self._names
78
79     def get_includes(self):
80         return self._includes
81
82     def set_container_types(self, list_ctypes, map_ctypes):
83         self._list_ctypes = list_ctypes
84         self._map_ctypes = map_ctypes
85
86     def set_strip_prefix(self, strip_prefix):
87         self._strip_prefix = strip_prefix
88
89     def parse(self):
90         nodes = []
91         for symbol in self.generator.get_symbols():
92             node = self._traverse_one(symbol)
93             self._add_node(node)
94         return self._namespace
95
96     def register_include(self, filename):
97         (path, suffix) = os.path.splitext(filename)
98         name = os.path.basename(path)
99         if name in self._includes:
100             return
101         if suffix == '':
102             suffix = '.gir'
103             filename = path + suffix
104         if suffix == '.gir':
105             source = filename
106             if not os.path.exists(filename):
107                 searchdirs = [os.path.join(d, 'gir') for d \
108                                   in _xdg_data_dirs]
109                 searchdirs.extend(self._includepaths)
110                 source = None
111                 for d in searchdirs:
112                     source = os.path.join(d, filename)
113                     if os.path.exists(source):
114                         break
115                     source = None
116             if not source:
117                 raise ValueError("Couldn't find include %r (search path: %r)"\
118                                      % (filename, searchdirs))
119             d = os.path.dirname(source)
120             if d not in self._includepaths:
121                 self._includepaths.append(d)
122             self._includes.add(name)
123             from .girparser import GIRParser
124             parser = GIRParser(source)
125         else:
126             raise NotImplementedError(filename)
127         for include in parser.get_includes():
128             self.register_include(include)
129         nsname = parser.get_namespace_name()
130         for node in parser.get_nodes():
131             if isinstance(node, Alias):
132                 self._names.aliases[node.name] = (nsname, node)
133             elif isinstance(node, (GLibBoxed, Interface, Class)):
134                 self._names.type_names[node.type_name] = (nsname, node)
135             self._names.names[node.name] = (nsname, node)
136             if hasattr(node, 'ctype'):
137                 self._names.ctypes[node.ctype] = (nsname, node)
138             elif hasattr(node, 'symbol'):
139                 self._names.ctypes[node.symbol] = (nsname, node)
140
141     def strip_namespace_object(self, name):
142         prefix = self._namespace.name.lower()
143         if len(name) > len(prefix) and name.lower().startswith(prefix):
144             return name[len(prefix):]
145         return self._remove_prefix(name)
146
147     # Private
148
149     def _add_node(self, node):
150         if node is None:
151             return
152         if node.name.startswith('_'):
153             return
154         self._namespace.nodes.append(node)
155         self._names.names[node.name] = (None, node)
156
157     def _strip_namespace_func(self, name):
158         prefix = self._namespace.name.lower() + '_'
159         if name.lower().startswith(prefix):
160             name = name[len(prefix):]
161         else:
162             prefix = to_underscores(self._namespace.name).lower() + '_'
163             if name.lower().startswith(prefix):
164                 name = name[len(prefix):]
165         return self._remove_prefix(name, isfunction=True)
166
167     def _remove_prefix(self, name, isfunction=False):
168         # when --strip-prefix=g:
169         #   GHashTable -> HashTable
170         #   g_hash_table_new -> hash_table_new
171         prefix = self._strip_prefix.lower()
172         if isfunction:
173             prefix += '_'
174         if name.lower().startswith(prefix):
175             name = name[len(prefix):]
176
177         while name.startswith('_'):
178             name = name[1:]
179         return name
180
181     def _traverse_one(self, symbol, stype=None):
182         assert isinstance(symbol, SourceSymbol), symbol
183
184         if stype is None:
185             stype = symbol.type
186         if stype == CSYMBOL_TYPE_FUNCTION:
187             try:
188                 return self._create_function(symbol)
189             except SkipError:
190                 return
191         elif stype == CSYMBOL_TYPE_TYPEDEF:
192             return self._create_typedef(symbol)
193         elif stype == CSYMBOL_TYPE_STRUCT:
194             return self._create_struct(symbol)
195         elif stype == CSYMBOL_TYPE_ENUM:
196             return self._create_enum(symbol)
197         elif stype == CSYMBOL_TYPE_OBJECT:
198             return self._create_object(symbol)
199         elif stype == CSYMBOL_TYPE_MEMBER:
200             return self._create_member(symbol)
201         elif stype == CSYMBOL_TYPE_UNION:
202             return self._create_union(symbol)
203         else:
204             raise NotImplementedError(
205                 'Transformer: unhandled symbol: %r' % (symbol, ))
206
207     def _create_enum(self, symbol):
208         members = []
209         for child in symbol.base_type.child_list:
210             name = strip_common_prefix(symbol.ident, child.ident).lower()
211             members.append(Member(name,
212                                   child.const_int,
213                                   child.ident))
214
215         enum_name = self.strip_namespace_object(symbol.ident)
216         enum_name = symbol.ident[-len(enum_name):]
217         enum_name = self._remove_prefix(enum_name)
218         enum = Enum(enum_name, symbol.ident, members)
219         self._names.type_names[symbol.ident] = (None, enum)
220         return enum
221
222     def _create_object(self, symbol):
223         return Member(symbol.ident, symbol.base_type.name,
224                       symbol.ident)
225
226     def _parse_deprecated(self, node, directives):
227         deprecated = directives.get('deprecated', False)
228         if deprecated:
229             deprecated_value = deprecated[0]
230             if ':' in deprecated_value:
231                 # Split out gtk-doc version
232                 (node.deprecated_version, node.deprecated) = \
233                     [x.strip() for x in deprecated_value.split(':', 1)]
234             else:
235                 # No version, just include str
236                 node.deprecated = deprecated_value.strip()
237
238     def _create_function(self, symbol):
239         directives = symbol.directives()
240         parameters = list(self._create_parameters(
241             symbol.base_type, directives))
242         return_ = self._create_return(symbol.base_type.base_type,
243                                       directives.get('return', []))
244         name = self._strip_namespace_func(symbol.ident)
245         func = Function(name, return_, parameters, symbol.ident)
246         self._parse_deprecated(func, directives)
247         return func
248
249     def _create_source_type(self, source_type):
250         if source_type is None:
251             return 'None'
252         if source_type.type == CTYPE_VOID:
253             value = 'void'
254         elif source_type.type == CTYPE_BASIC_TYPE:
255             value = source_type.name
256         elif source_type.type == CTYPE_TYPEDEF:
257             value = source_type.name
258         elif source_type.type == CTYPE_ARRAY:
259             return self._create_source_type(source_type.base_type)
260         elif source_type.type == CTYPE_POINTER:
261             value = self._create_source_type(source_type.base_type) + '*'
262         else:
263             print 'TRANSFORMER: Unhandled source type %r' % (
264                 source_type, )
265             value = 'any'
266         return value
267
268     def _create_parameters(self, base_type, options=None):
269         if not options:
270             options = {}
271         for child in base_type.child_list:
272             yield self._create_parameter(
273                 child, options.get(child.ident, []))
274
275     def _create_member(self, symbol):
276         ctype = symbol.base_type.type
277         if (ctype == CTYPE_POINTER and
278             symbol.base_type.base_type.type == CTYPE_FUNCTION):
279             node = self._create_callback(symbol)
280         else:
281             ftype = self._create_type(symbol.base_type)
282             node = Field(symbol.ident, ftype, symbol.ident)
283         return node
284
285     def _create_typedef(self, symbol):
286         ctype = symbol.base_type.type
287         if (ctype == CTYPE_POINTER and
288             symbol.base_type.base_type.type == CTYPE_FUNCTION):
289             node = self._create_callback(symbol)
290         elif ctype == CTYPE_STRUCT:
291             node = self._create_typedef_struct(symbol)
292         elif ctype == CTYPE_UNION:
293             node = self._create_typedef_union(symbol)
294         elif ctype == CTYPE_ENUM:
295             return self._create_enum(symbol)
296         elif ctype in (CTYPE_TYPEDEF,
297                        CTYPE_POINTER,
298                        CTYPE_BASIC_TYPE,
299                        CTYPE_VOID):
300             name = self.strip_namespace_object(symbol.ident)
301             if symbol.base_type.name:
302                 target = self.strip_namespace_object(symbol.base_type.name)
303             else:
304                 target = 'none'
305             if name in type_names:
306                 return None
307             return Alias(name, target, ctype=symbol.ident)
308         else:
309             raise NotImplementedError(
310                 "symbol %r of type %s" % (symbol.ident, ctype_name(ctype)))
311         return node
312
313     def _parse_and_resolve_ctype(self, ctype):
314         canonical = type_name_from_ctype(ctype)
315         derefed = canonical.replace('*', '')
316         return self.resolve_type_name(derefed)
317
318     def _create_type(self, source_type, options=[]):
319         ctype = self._create_source_type(source_type)
320         if ctype == 'va_list':
321             raise SkipError
322         # FIXME: FILE* should not be skipped, it should be handled
323         #        properly instead
324         elif ctype == 'FILE*':
325             raise SkipError
326         if ctype in self._list_ctypes:
327             if len(options) > 0:
328                 contained_type = self._parse_and_resolve_ctype(options[0])
329                 del options[0]
330             else:
331                 contained_type = None
332             return List(ctype.replace('*', ''),
333                         ctype,
334                         contained_type)
335         if ctype in self._list_ctypes:
336             if len(options) > 0:
337                 key_type = self._parse_and_resolve_ctype(options[0])
338                 value_type = self._parse_and_resolve_ctype(options[1])
339                 del options[0:2]
340             else:
341                 key_type = None
342                 value_type = None
343             return Map(ctype.replace('*', ''),
344                        ctype,
345                        key_type, value_type)
346         if (ctype in default_array_types) or ('array' in options):
347             if 'array' in options:
348                 options.remove('array')
349             derefed = ctype[:-1] # strip the *
350             return Array(None, ctype,
351                          type_name_from_ctype(derefed))
352         resolved_type_name = self._parse_and_resolve_ctype(ctype)
353         return Type(resolved_type_name, ctype)
354
355     def _create_parameter(self, symbol, options):
356         ptype = self._create_type(symbol.base_type, options)
357         param = Parameter(symbol.ident, ptype)
358         for option in options:
359             if option in ['in-out', 'inout']:
360                 param.direction = 'inout'
361             elif option == 'in':
362                 param.direction = 'in'
363             elif option == 'out':
364                 param.direction = 'out'
365             elif option == 'transfer':
366                 param.transfer = True
367             elif option == 'notransfer':
368                 param.transfer = False
369             elif option == 'allow-none':
370                 param.allow_none = True
371             else:
372                 print 'Unhandled parameter annotation option: %r' % (
373                     option, )
374         return param
375
376     def _create_return(self, source_type, options=[]):
377         rtype = self._create_type(source_type, options)
378         rtype = self.resolve_param_type(rtype)
379         return_ = Return(rtype)
380         for option in options:
381             if option == 'transfer':
382                 return_.transfer = True
383             else:
384                 print 'Unhandled parameter annotation option: %r' % (
385                     option, )
386         return return_
387
388     def _create_typedef_struct(self, symbol):
389         name = self.strip_namespace_object(symbol.ident)
390         struct = Struct(name, symbol.ident)
391         self._typedefs_ns[symbol.ident] = struct
392         return struct
393
394     def _create_typedef_union(self, symbol):
395         name = self._remove_prefix(symbol.ident)
396         name = self.strip_namespace_object(name)
397         union = Union(name, symbol.ident)
398         self._typedefs_ns[symbol.ident] = union
399         return union
400
401     def _create_struct(self, symbol):
402         struct = self._typedefs_ns.get(symbol.ident, None)
403         if struct is None:
404             # This is a bit of a hack; really we should try
405             # to resolve through the typedefs to find the real
406             # name
407             if symbol.ident.startswith('_'):
408                 name = symbol.ident[1:]
409             else:
410                 name = symbol.ident
411             name = self.strip_namespace_object(name)
412             name = self.resolve_type_name(name)
413             struct = Struct(name, symbol.ident)
414
415         for child in symbol.base_type.child_list:
416             field = self._traverse_one(child)
417             if field:
418                 struct.fields.append(field)
419
420         return struct
421
422     def _create_union(self, symbol):
423         union = self._typedefs_ns.get(symbol.ident, None)
424         if union is None:
425             # This is a bit of a hack; really we should try
426             # to resolve through the typedefs to find the real
427             # name
428             if symbol.ident.startswith('_'):
429                 name = symbol.ident[1:]
430             else:
431                 name = symbol.ident
432             name = self.strip_namespace_object(name)
433             name = self.resolve_type_name(name)
434             union = Union(name, symbol.ident)
435
436         for child in symbol.base_type.child_list:
437             field = self._traverse_one(child)
438             if field:
439                 union.fields.append(field)
440
441         return union
442
443     def _create_callback(self, symbol):
444         parameters = self._create_parameters(symbol.base_type.base_type)
445         retval = self._create_return(symbol.base_type.base_type.base_type)
446         if symbol.ident.find('_') > 0:
447             name = self._strip_namespace_func(symbol.ident)
448         else:
449             name = self.strip_namespace_object(symbol.ident)
450         return Callback(name, retval, list(parameters), symbol.ident)
451
452     def _typepair_to_str(self, item):
453         nsname, item = item
454         if nsname is None:
455             return item.name
456         return '%s.%s' % (nsname, item.name)
457
458     def _resolve_type_name_1(self, type_name, ctype, names):
459         # First look using the built-in names
460         if ctype:
461             try:
462                 return type_names[ctype]
463             except KeyError, e:
464                 pass
465         try:
466             return type_names[type_name]
467         except KeyError, e:
468             pass
469         type_name = self.strip_namespace_object(type_name)
470         resolved = names.aliases.get(type_name)
471         if resolved:
472             return self._typepair_to_str(resolved)
473         resolved = names.names.get(type_name)
474         if resolved:
475             return self._typepair_to_str(resolved)
476         if ctype:
477             ctype = ctype.replace('*', '')
478             resolved = names.ctypes.get(ctype)
479             if resolved:
480                 return self._typepair_to_str(resolved)
481         resolved = names.type_names.get(type_name)
482         if resolved:
483             return self._typepair_to_str(resolved)
484         raise KeyError("failed to find %r" % (type_name, ))
485
486     def resolve_type_name_full(self, type_name, ctype,
487                                names):
488         try:
489             return self._resolve_type_name_1(type_name, ctype, names)
490         except KeyError, e:
491             try:
492                 return self._resolve_type_name_1(type_name, ctype, self._names)
493             except KeyError, e:
494                 return type_name
495
496     def resolve_type_name(self, type_name, ctype=None):
497         try:
498             return self.resolve_type_name_full(type_name, ctype, self._names)
499         except KeyError, e:
500             return type_name
501
502     def gtypename_to_giname(self, gtname, names):
503         resolved = names.type_names.get(gtname)
504         if resolved:
505             return self._typepair_to_str(resolved)
506         resolved = self._names.type_names.get(gtname)
507         if resolved:
508             return self._typepair_to_str(resolved)
509         raise KeyError("Failed to resolve GType name: %r" % (gtname, ))
510
511     def ctype_of(self, obj):
512         if hasattr(obj, 'ctype'):
513             return obj.ctype
514         elif hasattr(obj, 'symbol'):
515             return obj.symbol
516         else:
517             return None
518
519     def resolve_param_type_full(self, ptype, names):
520         if isinstance(ptype, Array):
521             ptype.element_type = \
522                 self.resolve_param_type_full(ptype.element_type, names)
523         elif isinstance(ptype, Node):
524             ptype.name = self.resolve_type_name_full(ptype.name,
525                                                      self.ctype_of(ptype),
526                                                      names)
527         elif isinstance(ptype, basestring):
528             return self.resolve_type_name_full(ptype, None, names)
529         else:
530             raise AssertionError("Unhandled param: %r" % (ptype, ))
531         return ptype
532
533     def resolve_param_type(self, ptype):
534         try:
535             return self.resolve_param_type_full(ptype, self._names)
536         except KeyError, e:
537             return ptype
538
539     def follow_aliases(self, type_name, names):
540         while True:
541             resolved = names.aliases.get(type_name)
542             if resolved:
543                 (ns, alias) = resolved
544                 type_name = alias.target
545             else:
546                 break
547         return type_name